Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve BitonicSort performance for sorting floats #952

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions benchmark/bench_sort.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
module BenchSort

using BenchmarkTools
using Random: rand!
using StaticArrays
using StaticArrays: BitonicSort

const SUITE = BenchmarkGroup()

# 1 second is sufficient for reasonably consistent timings.
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 1

const LEN = 1000

const Floats = (Float16, Float32, Float64)
const Ints = (Int8, Int16, Int32, Int64, Int128)
const UInts = (UInt8, UInt16, UInt32, UInt64, UInt128)

map_sort!(vs; kwargs...) = map!(v -> sort(v; kwargs...), vs, vs)

addgroup!(SUITE, "BitonicSort")

g = addgroup!(SUITE["BitonicSort"], "SVector")
for lt in (isless, <)
n = 1
while (n = nextprod([2, 3], n + 1)) <= 24
for T in (Floats..., Ints..., UInts...)
(lt === <) && (T <: Integer) && continue # For Integers, isless is <.
vs = Vector{SVector{n, T}}(undef, LEN)
g[lt, n, T] = @benchmarkable(
map_sort!($vs; alg=BitonicSort, lt=$lt),
evals=1, # Redundant on @benchmarkable as of BenchmarkTools 1.1.3.
# We need evals=1 so that setup runs before every eval. But PkgBenchmark
# always `tunes!` benchmarks before running, which overrides this. As a
# workaround, use the unhygienic symbol `__params` to set evals just before
# execution at
# https://github.com/JuliaCI/BenchmarkTools.jl/blob/v1.1.3/src/execution.jl#L482
# See also: https://github.com/JuliaCI/PkgBenchmark.jl/issues/120
setup=(__params.evals = 1; rand!($vs)),
)
end
end
end

g = addgroup!(SUITE["BitonicSort"], "MVector")
for (lt, n, T) in ((isless, 16, Int64), (isless, 16, Float64), (<, 16, Float64))
vs = Vector{MVector{n, T}}(undef, LEN)
g[lt, n, T] = @benchmarkable(
map_sort!($vs; alg=BitonicSort, lt=$lt),
evals=1,
setup=(__params.evals = 1; rand!($vs)),
)
end

g = addgroup!(SUITE["BitonicSort"], "SizedVector")
for (lt, n, T) in ((isless, 16, Int64), (isless, 16, Float64), (<, 16, Float64))
vs = Vector{SizedVector{n, T, Vector{T}}}(undef, LEN)
g[lt, n, T] = @benchmarkable(
map_sort!($vs; alg=BitonicSort, lt=$lt),
evals=1,
setup=(__params.evals = 1; rand!($vs)),
)
end

# @generated to unroll the tuple.
@generated function floats_nans(::Type{SVector{N, T}}, p) where {N, T}
exprs = (:(ifelse(rand(Float32) < p, T(NaN), rand(T))) for _ in 1:N)
return quote
Base.@_inline_meta
return SVector(($(exprs...),))
end
end

function floats_nans!(vs::Vector{SVector{N, T}}, p) where {N, T}
for i in eachindex(vs)
@inbounds vs[i] = floats_nans(SVector{N, T}, p)
end
return vs
end

g = addgroup!(SUITE["BitonicSort"], "NaNs")
for p in (0.001, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0)
(lt, n, T) = (isless, 16, Float64)
vs = Vector{SVector{n, T}}(undef, LEN)
g[lt, n, T, p] = @benchmarkable(
map_sort!($vs; alg=BitonicSort, lt=$lt),
evals=1,
setup=(__params.evals = 1; floats_nans!($vs, $p)),
)
end

end # module BenchSort

# Allow PkgBenchmark.benchmarkpkg to call this file directly.
@isdefined(SUITE) || (SUITE = BenchSort.SUITE)

BenchSort.SUITE
1 change: 1 addition & 0 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ include("indexing.jl")
include("broadcast.jl")
include("mapreduce.jl")
include("sort.jl")
using .Sort
include("arraymath.jl")
include("linalg.jl")
include("matrix_multiply_add.jl")
Expand Down
106 changes: 95 additions & 11 deletions src/sort.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,30 @@
import Base.Order: Forward, Ordering, Perm, ord
import Base.Sort: Algorithm, lt, sort, sortperm
module Sort

import Base: sort, sortperm

using ..StaticArrays
using Base: @_inline_meta
using Base.Order: Forward, Ordering, Perm, Reverse, ord
using Base.Sort: Algorithm, lt

export BitonicSort

struct BitonicSortAlg <: Algorithm end

# For consistency with Julia Base, track their *Sort docstring text in base/sort.jl.
"""
StaticArrays.BitonicSort

Indicate that a sorting function should use a bitonic sorting network, which is *not*
stable. By default, `StaticVector`s with at most 20 elements are sorted with `BitonicSort`.

Characteristics:
* *not stable*: does not preserve the ordering of elements which compare equal (e.g. "a"
and "A" in a sort of letters which ignores case).
* *in-place* in memory.
* *good performance* for small collections.
* compilation time increases dramatically with the number of elements.
"""
const BitonicSort = BitonicSortAlg()


Expand All @@ -19,8 +40,7 @@ defalg(a::StaticVector) =
rev::Union{Bool,Nothing} = nothing,
order::Ordering = Forward)
length(a) <= 1 && return a
ordr = ord(lt, by, rev, order)
return _sort(a, alg, ordr)
return _sort(a, alg, lt, by, rev, order)
end

@inline function sortperm(a::StaticVector;
Expand All @@ -33,21 +53,83 @@ end
length(a) <= 1 && return SVector{length(a),Int}(p)

ordr = Perm(ord(lt, by, rev, order), a)
return SVector{length(a),Int}(_sort(p, alg, ordr))
return SVector{length(a),Int}(_sort(p, alg, isless, identity, nothing, ordr))
vyu marked this conversation as resolved.
Show resolved Hide resolved
end

@inline _sort(a::StaticVector, alg, lt, by, rev, order) =
similar_type(a)(sort!(Base.copymutable(a); alg=alg, lt=lt, by=by, rev=rev, order=order))

@inline _sort(a::StaticVector, alg::BitonicSortAlg, lt, by, rev, order) =
similar_type(a)(_sort(Tuple(a), alg, lt, by, rev, order))

@inline _sort(a::NTuple, alg, lt, by, rev, order) =
sort!(Base.copymutable(a); alg=alg, lt=lt, by=by, rev=rev, order=order)

@inline _sort(a::NTuple, ::BitonicSortAlg, lt, by, rev, order) =
_bitonic_sort(a, ord(lt, by, rev, order))

@inline _sort(a::StaticVector, alg, order) =
similar_type(a)(sort!(Base.copymutable(a); alg=alg, order=order))
# For better performance sorting floats under the isless relation, apply an order-preserving
# bijection to sort them as integers.
@inline function _sort(
a::NTuple{N, <:Union{Float16, Float32, Float64}}, ::BitonicSortAlg, lt, by, rev, order
vyu marked this conversation as resolved.
Show resolved Hide resolved
) where N
# Skip this special treatment when N = 2 to avoid a performance regression on AArch64.
N <= 2 && return _bitonic_sort(a, ord(lt, by, rev, order))
lt_rev = _simplify_order(lt, by, rev, order)
if lt_rev === nothing || lt_rev[1] !== isless
return _bitonic_sort(a, ord(lt, by, rev, order))
end
return _intfp.(_bitonic_sort(_fpint.(a), ord(isless, identity, lt_rev[2], Forward)))
end

@inline _sort(a::StaticVector, alg::BitonicSortAlg, order) =
similar_type(a)(_sort(Tuple(a), alg, order))
# Given the order ord(lt, by, rev, order) on floats or integers, attempt to simplify it to
# ord(_lt, identity, _rev, Forward), where _rev is a Bool and _lt is isless or <. If
# successful, return (_lt, _rev). Otherwise, return `nothing`.
vyu marked this conversation as resolved.
Show resolved Hide resolved
@inline function _simplify_order(lt, by, rev::Union{Bool, Nothing}, order::Ordering)
(
any(Ref(lt) .=== (isless, <, >)) &&
vyu marked this conversation as resolved.
Show resolved Hide resolved
any(Ref(by) .=== (identity, +, -)) &&
any(Ref(order) .=== (Forward, Reverse))
) || return nothing
rev = xor(lt === >, by === -, rev === true, order === Reverse)
lt = ifelse(lt === >, <, lt)
return (lt, rev)
end

_sort(a::NTuple, alg, order) = sort!(Base.copymutable(a); alg=alg, order=order)
_inttype(::Type{Float64}) = Int64
_inttype(::Type{Float32}) = Int32
_inttype(::Type{Float16}) = Int16

_floattype(::Type{Int64}) = Float64
_floattype(::Type{Int32}) = Float32
_floattype(::Type{Int16}) = Float16

# Modified from the _fpint function added to base/float.jl in Julia 1.7. This is a strictly
# increasing function with respect to the isless relation. `isless` is trichotomous with the
# isequal relation and treats every NaN as identical. This function on the other hand
# distinguishes between NaNs with different payloads and signs, but this difference is
# inconsequential for unstable sorting. The `offset` is necessary because NaNs (in
# particular, those with the sign bit set) must be mapped to the greatest Ints, which is
# Julia-specific.
@inline function _fpint(x::F) where F
I = _inttype(F)
offset = reinterpret(I, typemin(F)) ⊻ -one(I)
vyu marked this conversation as resolved.
Show resolved Hide resolved
n = reinterpret(I, x)
return ifelse(n < zero(I), n ⊻ typemax(I), n) - offset
end

# Inverse of _fpint.
@inline function _intfp(n::I) where I
F = _floattype(I)
offset = reinterpret(I, typemin(F)) ⊻ -one(I)
n += offset
n = ifelse(n < zero(I), n ⊻ typemax(I), n)
return reinterpret(F, n)
end

# Implementation loosely following
# https://www.inf.hs-flensburg.de/lang/algorithmen/sortieren/bitonic/oddn.htm
@generated function _sort(a::NTuple{N}, ::BitonicSortAlg, order) where N
@generated function _bitonic_sort(a::NTuple{N}, order) where N
function swap_expr(i, j, rev)
ai = Symbol('a', i)
aj = Symbol('a', j)
Expand Down Expand Up @@ -87,3 +169,5 @@ _sort(a::NTuple, alg, order) = sort!(Base.copymutable(a); alg=alg, order=order)
return ($(symlist...),)
end
end

end # module Sort
95 changes: 94 additions & 1 deletion test/sort.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
module SortTests

using StaticArrays, Test
using Base.Order: Forward, Reverse

_inttype(::Type{Float64}) = Int64
_inttype(::Type{Float32}) = Int32
_inttype(::Type{Float16}) = Int16
vyu marked this conversation as resolved.
Show resolved Hide resolved

float_or(x::T, y::T) where T = reinterpret(T, |(reinterpret.(_inttype(T), (x, y))...))

@testset "sort" begin

Expand Down Expand Up @@ -30,4 +39,88 @@ using StaticArrays, Test
@test sortperm(SA[1, 1, 1, 0]) == SA[4, 1, 2, 3]
end

end
@testset "NaNs" begin
# Return an SVector with floats and NaNs that have random sign and payload bits.
@generated function floats_randnans(::Type{SVector{N, T}}, p) where {N, T}
vyu marked this conversation as resolved.
Show resolved Hide resolved
exprs = Base.Generator(1:N) do _
quote
r = rand(T)
# The bitwise or of any T with T(Inf) is either ±T(Inf) or a NaN.
ifelse(rand(Float32) < p, float_or(typemax(T), r - T(0.5)), r)
end
end
return quote
Base.@_inline_meta
return SVector(($(exprs...),))
end
end

# Sort floats and arbitrary NaNs.
for T in (Float16, Float32, Float64)
buffer = Vector{T}(undef, 16)
@test all(floats_randnans(SVector{16, T}, 0.5) for _ in 1:10_000) do a
copyto!(buffer, a)
isequal(sort(a), sort!(buffer))
end
end

# Sort signed Infs, signed zeros, and NaNs with extremal payloads.
for T in (Float16, Float32, Float64)
U = _inttype(T)
small_nan = reinterpret(T, reinterpret(U, typemax(T)) + one(U))
large_nan = reinterpret(T, typemax(U))
nans = (small_nan, large_nan, T(NaN), -small_nan, -large_nan, -T(NaN))
(a, b, c, d) = (-T(Inf), -zero(T), zero(T), T(Inf))
sorted = [a, b, c, d, nans..., nans...]
@test isequal(sorted, sort(SA[nans..., d, c, b, a, nans...]))
@test isequal(sorted, sort(SA[d, c, nans..., nans..., b, a]))
end
end

# These tests are selected from Julia's test/ordering.jl and test/sorting.jl.
@testset "Base tests" begin
# This testset partially fails on Julia versions < 1.5 because order could be
# discarded: https://github.com/JuliaLang/julia/pull/34719
if VERSION >= v"1.5"
@testset "ordering" begin
for (s1, rev) in enumerate([nothing, true, false])
for (s2, lt) in enumerate([>, <, (a, b) -> a - b > 0, (a, b) -> a - b < 0])
for (s3, by) in enumerate([-, +])
for (s4, order) in enumerate([Reverse, Forward])
if isodd(s1 + s2 + s3 + s4)
target = SA[1, 2, 3]
else
target = SA[3, 2, 1]
end
@test target == sort(SA[2, 3, 1], rev=rev, lt=lt, by=by, order=order)
end
end
end
end

@test SA[1 => 3, 2 => 5, 3 => 1] ==
sort(SA[1 => 3, 2 => 5, 3 => 1]) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], by=first) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], rev=true, order=Reverse) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], lt= >, order=Reverse)

@test SA[3 => 1, 1 => 3, 2 => 5] ==
sort(SA[1 => 3, 2 => 5, 3 => 1], by=last) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], by=last, rev=true, order=Reverse) ==
sort(SA[1 => 3, 2 => 5, 3 => 1], by=last, lt= >, order=Reverse)
end
end

@testset "sort" begin
@test sort(SA[2,3,1]) == SA[1,2,3] == sort(SA[2,3,1]; order=Forward)
@test sort(SA[2,3,1], rev=true) == SA[3,2,1] == sort(SA[2,3,1], order=Reverse)
@test sort(SA['z':-1:'a'...]) == SA['a':'z'...]
@test sort(SA['a':'z'...], rev=true) == SA['z':-1:'a'...]
end

@test sortperm(SA[2,3,1]) == SA[3,1,2]
end

end # @testset "sort"

end # module SortTests