Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve BitonicSort performance for sorting floats #952

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions benchmark/bench_sort.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
module BenchSort

using BenchmarkTools
using Random: rand!
using StaticArrays
using StaticArrays: BitonicSort

const SUITE = BenchmarkGroup()

# 1 second is sufficient for reasonably consistent timings.
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 1

const LEN = 1000

const Floats = (Float16, Float32, Float64)
const Ints = (Int8, Int16, Int32, Int64, Int128)
const UInts = (UInt8, UInt16, UInt32, UInt64, UInt128)

map_sort!(vs; kwargs...) = map!(v -> sort(v; kwargs...), vs, vs)

addgroup!(SUITE, "BitonicSort")

g = addgroup!(SUITE["BitonicSort"], "SVector")
for lt in (isless, <)
n = 1
while (n = nextprod([2, 3], n + 1)) <= 24
for T in (Floats..., Ints..., UInts...)
(lt === <) && (T <: Integer) && continue # For Integers, isless is <.
vs = Vector{SVector{n, T}}(undef, LEN)
g[lt, n, T] = @benchmarkable(
map_sort!($vs; alg=BitonicSort, lt=$lt),
evals=1, # Redundant on @benchmarkable as of BenchmarkTools 1.1.3.
# We need evals=1 so that setup runs before every eval. But PkgBenchmark
# always `tunes!` benchmarks before running, which overrides this. As a
# workaround, use the unhygienic symbol `__params` to set evals just before
# execution at
# https://github.com/JuliaCI/BenchmarkTools.jl/blob/v1.1.3/src/execution.jl#L482
# See also: https://github.com/JuliaCI/PkgBenchmark.jl/issues/120
setup=(__params.evals = 1; rand!($vs)),
)
end
end
end

g = addgroup!(SUITE["BitonicSort"], "MVector")
for (lt, n, T) in ((isless, 16, Int64), (isless, 16, Float64), (<, 16, Float64))
vs = Vector{MVector{n, T}}(undef, LEN)
g[lt, n, T] = @benchmarkable(
map_sort!($vs; alg=BitonicSort, lt=$lt),
evals=1,
setup=(__params.evals = 1; rand!($vs)),
)
end

g = addgroup!(SUITE["BitonicSort"], "SizedVector")
for (lt, n, T) in ((isless, 16, Int64), (isless, 16, Float64), (<, 16, Float64))
vs = Vector{SizedVector{n, T, Vector{T}}}(undef, LEN)
g[lt, n, T] = @benchmarkable(
map_sort!($vs; alg=BitonicSort, lt=$lt),
evals=1,
setup=(__params.evals = 1; rand!($vs)),
)
end

function map_floats_nans!(vs::Vector{SVector{N, T}}, p) where {N, T}
@inline _rand(_) = ifelse(rand(Float32) < p, T(NaN), rand(T))
for i in eachindex(vs)
@inbounds vs[i] = SVector(ntuple(_rand, Val(N)))
end
return vs
end

g = addgroup!(SUITE["BitonicSort"], "NaNs")
for p in (0.001, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0)
(lt, n, T) = (isless, 16, Float64)
vs = Vector{SVector{n, T}}(undef, LEN)
g[lt, n, T, p] = @benchmarkable(
map_sort!($vs; alg=BitonicSort, lt=$lt),
evals=1,
setup=(__params.evals = 1; map_floats_nans!($vs, $p)),
)
end

end # module BenchSort

# Allow PkgBenchmark.benchmarkpkg to call this file directly.
@isdefined(SUITE) || (SUITE = BenchSort.SUITE)

BenchSort.SUITE
1 change: 1 addition & 0 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ include("indexing.jl")
include("broadcast.jl")
include("mapreduce.jl")
include("sort.jl")
using .Sort
include("arraymath.jl")
include("linalg.jl")
include("matrix_multiply_add.jl")
Expand Down
98 changes: 86 additions & 12 deletions src/sort.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,30 @@
import Base.Order: Forward, Ordering, Perm, ord
import Base.Sort: Algorithm, lt, sort, sortperm
module Sort

import Base: sort, sortperm

using ..StaticArrays
using Base: @_inline_meta
using Base.Order: Forward, Ordering, Perm, Reverse, ord
using Base.Sort: Algorithm, lt

export BitonicSort

struct BitonicSortAlg <: Algorithm end

# For consistency with Julia Base, track their *Sort docstring text in base/sort.jl.
"""
StaticArrays.BitonicSort

Indicate that a sorting function should use a bitonic sorting network, which is *not*
stable. By default, `StaticVector`s with at most 20 elements are sorted with `BitonicSort`.

Characteristics:
* *not stable*: does not preserve the ordering of elements which compare equal (e.g. "a"
and "A" in a sort of letters which ignores case).
* *in-place* in memory.
* *good performance* for small collections.
* compilation time increases dramatically with the number of elements.
"""
const BitonicSort = BitonicSortAlg()


Expand All @@ -19,8 +40,7 @@ defalg(a::StaticVector) =
rev::Union{Bool,Nothing} = nothing,
order::Ordering = Forward)
length(a) <= 1 && return a
ordr = ord(lt, by, rev, order)
return _sort(a, alg, ordr)
return _sort(a, alg, lt, by, rev, order)
end

@inline function sortperm(a::StaticVector;
Expand All @@ -33,21 +53,73 @@ end
length(a) <= 1 && return SVector{length(a),Int}(p)

ordr = Perm(ord(lt, by, rev, order), a)
return SVector{length(a),Int}(_sort(p, alg, ordr))
return SVector{length(a),Int}(_sort(p, alg, isless, identity, nothing, ordr))
vyu marked this conversation as resolved.
Show resolved Hide resolved
end

@inline _sort(a::StaticVector, alg, lt, by, rev, order) =
similar_type(a)(sort!(Base.copymutable(a); alg=alg, lt=lt, by=by, rev=rev, order=order))

@inline _sort(a::StaticVector, alg::BitonicSortAlg, lt, by, rev, order) =
similar_type(a)(_sort(Tuple(a), alg, lt, by, rev, order))

@inline _sort(a::NTuple, alg, lt, by, rev, order) =
sort!(Base.copymutable(a); alg=alg, lt=lt, by=by, rev=rev, order=order)

@inline _sort(a::NTuple, ::BitonicSortAlg, lt, by, rev, order) =
_bitonic_sort(a, ord(lt, by, rev, order))

# For better performance sorting floats under the isless relation, apply an order-preserving
# bijection to sort them as integers.
@inline function _sort(
a::NTuple{N, <:Base.IEEEFloat},
::BitonicSortAlg,
lt::typeof(isless),
by::Union{typeof.((identity, +, -))...},
rev::Union{Bool, Nothing},
order,
) where N
# Exclude N == 2 to avoid a performance regression on AArch64.
if N > 2 && (order === Forward || order === Reverse)
_rev = xor(by === -, rev === true, order === Reverse)
return _intfp.(_bitonic_sort(_fpint.(a), ord(isless, identity, _rev, Forward)))
end
return _bitonic_sort(a, ord(lt, by, rev, order))
end

@inline _sort(a::StaticVector, alg, order) =
similar_type(a)(sort!(Base.copymutable(a); alg=alg, order=order))

@inline _sort(a::StaticVector, alg::BitonicSortAlg, order) =
similar_type(a)(_sort(Tuple(a), alg, order))
_inttype(::Type{Float64}) = Int64
_inttype(::Type{Float32}) = Int32
_inttype(::Type{Float16}) = Int16

_floattype(::Type{Int64}) = Float64
_floattype(::Type{Int32}) = Float32
_floattype(::Type{Int16}) = Float16

# Modified from the _fpint function added to base/float.jl in Julia 1.7. This is a strictly
# increasing function with respect to the isless relation. `isless` is trichotomous with the
# isequal relation and treats every NaN as identical. This function on the other hand
# distinguishes between NaNs with different payloads and signs, but this difference is
# inconsequential for unstable sorting. The `offset` is necessary because NaNs (in
# particular, those with the sign bit set) must be mapped to the greatest Ints, which is
# Julia-specific.
@inline function _fpint(x::F) where F
I = _inttype(F)
offset = Base.significand_mask(F) % I
n = reinterpret(I, x)
return ifelse(n < zero(I), n ⊻ typemax(I), n) - offset
end

_sort(a::NTuple, alg, order) = sort!(Base.copymutable(a); alg=alg, order=order)
# Inverse of _fpint.
@inline function _intfp(n::I) where I
F = _floattype(I)
offset = Base.significand_mask(F) % I
n += offset
n = ifelse(n < zero(I), n ⊻ typemax(I), n)
return reinterpret(F, n)
end

# Implementation loosely following
# https://www.inf.hs-flensburg.de/lang/algorithmen/sortieren/bitonic/oddn.htm
@generated function _sort(a::NTuple{N}, ::BitonicSortAlg, order) where N
@generated function _bitonic_sort(a::NTuple{N}, order) where N
function swap_expr(i, j, rev)
ai = Symbol('a', i)
aj = Symbol('a', j)
Expand Down Expand Up @@ -87,3 +159,5 @@ _sort(a::NTuple, alg, order) = sort!(Base.copymutable(a); alg=alg, order=order)
return ($(symlist...),)
end
end

end # module Sort
90 changes: 89 additions & 1 deletion test/sort.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
module SortTests

using StaticArrays, Test
using StaticArrays.Sort: _inttype
using Base.Order: Forward, Reverse

@testset "sort" begin

Expand Down Expand Up @@ -30,4 +34,88 @@ using StaticArrays, Test
@test sortperm(SA[1, 1, 1, 0]) == SA[4, 1, 2, 3]
end

end
@testset "NaNs" begin
# Return an SVector with floats and NaNs that have random sign and payload bits.
function floats_randnans(::Type{SVector{N, T}}, p) where {N, T}
float_or(x, y) = reinterpret(T, |(reinterpret.(_inttype(T), (x, y))...))
@inline function _rand(_)
r = rand(T)
# The bitwise or of any T with T(Inf) is either ±T(Inf) or a NaN.
ifelse(rand(Float32) < p, float_or(typemax(T), r - T(0.5)), r)
end
return SVector(ntuple(_rand, Val(N)))
end

# Sort floats and arbitrary NaNs.
for T in (Float16, Float32, Float64)
buffer = Vector{T}(undef, 16)
@test all(floats_randnans(SVector{16, T}, 0.5) for _ in 1:10_000) do a
copyto!(buffer, a)
isequal(sort(a), sort!(buffer))
end
end

# Sort signed Infs, signed zeros, and signed NaNs with extremal payloads.
for T in (Float16, Float32, Float64)
U = _inttype(T)
small_nan = reinterpret(T, reinterpret(U, typemax(T)) + one(U))
large_nan = reinterpret(T, typemax(U))
nans = (small_nan, large_nan, T(NaN), -small_nan, -large_nan, -T(NaN))
(a, b, c, d) = (-T(Inf), -zero(T), zero(T), T(Inf))
sorted = [a, b, c, d, nans..., nans...]
@test isequal(sorted, sort(SA[nans..., d, c, b, a, nans...]))
@test isequal(sorted, sort(SA[d, c, nans..., nans..., b, a]))
end
end

# These tests are selected and modified from Julia's test/ordering.jl and test/sorting.jl.
@testset "Base tests" begin
# This testset partially fails on Julia versions < 1.5 because order could be
# discarded: https://github.com/JuliaLang/julia/pull/34719
if VERSION >= v"1.5"
@testset "ordering" begin
for T in (Int, Float64)
for (s1, rev) in enumerate([nothing, true, false])
for (s2, lt) in enumerate([>, <, (a, b) -> a - b > 0, (a, b) -> a - b < 0])
for (s3, by) in enumerate([-, +])
for (s4, order) in enumerate([Reverse, Forward])
if isodd(s1 + s2 + s3 + s4)
target = T.(SA[1, 2, 3])
else
target = T.(SA[3, 2, 1])
end
@test target == sort(T.(SA[2, 3, 1]), rev=rev, lt=lt, by=by, order=order)
end
end
end
end
end

@test SA[1 => 3, 2 => 5, 3 => 1] ==
sort(SA[1 => 3, 2 => 5, 3 => 1]) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], by=first) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], rev=true, order=Reverse) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], lt= >, order=Reverse)

@test SA[3 => 1, 1 => 3, 2 => 5] ==
sort(SA[1 => 3, 2 => 5, 3 => 1], by=last) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], by=last, rev=true, order=Reverse) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], by=last, lt= >, order=Reverse)
end
end

@testset "sort" begin
for T in (Int, Float64)
@test sort(T.(SA[2,3,1])) == T.(SA[1,2,3]) == sort(T.(SA[2,3,1]); order=Forward)
@test sort(T.(SA[2,3,1]), rev=true) == T.(SA[3,2,1]) == sort(T.(SA[2,3,1]), order=Reverse)
end
@test sort(SA['z':-1:'a'...]) == SA['a':'z'...]
@test sort(SA['a':'z'...], rev=true) == SA['z':-1:'a'...]
end

@test sortperm(SA[2,3,1]) == SA[3,1,2]
end

end # @testset "sort"

end # module SortTests