Skip to content

Commit

Permalink
support sorting tuples
Browse files Browse the repository at this point in the history
Uses merge sort, as an obvious choice for a stable sort of tuples.

A recursive data structure of singleton type, representing Peano
natural numbers, is used to help with splitting a tuple into two halves
in the merge sort. An alternative design would use a reference tuple,
but this would require relying on `tail`, which seems more harsh on the
compiler. With the recursive datastructure the predecessor operation
and the successor operation are both trivial.

Allows inference to preserve inferred element type even when tuple
length is not known.

Follow-up PRs may add further improvements, such as the ability to
select an unstable sorting algorithm.

The added file, typedomainnumbers.jl is not specific to sorting, thus
making it a separate file. Xref JuliaLang#55571.

Fixes JuliaLang#54489
  • Loading branch information
nsajko committed Nov 10, 2024
1 parent cd748a5 commit 59a5c2d
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 6 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ New library features
* `Base.require_one_based_indexing` and `Base.has_offset_axes` are now public ([#56196])
* New `ltruncate`, `rtruncate` and `ctruncate` functions for truncating strings to text width, accounting for char widths ([#55351])
* `isless` (and thus `cmp`, sorting, etc.) is now supported for zero-dimensional `AbstractArray`s ([#55772])
* `sort` now sorts tuples (#56425)

Standard library changes
------------------------
Expand Down
2 changes: 2 additions & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ include("cartesian.jl")
using .Cartesian
include("multidimensional.jl")

include("typedomainnumbers.jl")

include("broadcast.jl")
using .Broadcast
using .Broadcast: broadcasted, broadcasted_kwsyntax, materialize, materialize!,
Expand Down
76 changes: 76 additions & 0 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1736,6 +1736,82 @@ julia> v
"""
sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...)

module _SortTupleStable
using
Base._TypeDomainNumbers.PositiveIntegers, Base._TypeDomainNumbers.IntegersGreaterThanOne,
Base._TypeDomainNumberTupleUtils, Base._TupleTypeByLength
using Base: tail
using Base.Order: Ordering, lt
function merge_recursive((@nospecialize ord::Ordering), a::Tuple, b::Tuple)
if a isa Tuple1OrMore
a
else
b
end
end
function merge_recursive(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore)
l = first(a)
r = first(b)
x = tail(a)
y = tail(b)
if lt(ord, r, l)
let rec = merge_recursive(ord, a, y)
(r, rec...)
end
else
let rec = merge_recursive(ord, x, b)
(l, rec...)
end
end
end
function merge_nontrivial(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore)
merge_recursive(ord, a, b)
end
function sort_recursive((@nospecialize ord::Ordering), @nospecialize tup::Tuple{Any})
tup
end
function sort_recursive(ord::Ordering, tup::Tuple2OrMore)
(tup_l, tup_r) = split_tuple_into_halves(tup)
sorted_l = sort_recursive(ord, tup_l)
sorted_r = sort_recursive(ord, tup_r)
merge_nontrivial(ord, sorted_l, sorted_r)
end
function sort_tuple_stable_2_or_more(ord::Ordering, tup::Tuple2OrMore)
sort_recursive(ord, tup)
end
function sort_tuple_array_fallback(ord::Ordering, tup::Tuple2OrMore)
vec = if tup isa NTuple
[tup...]
else
Any[tup...]
end
sort!(vec; order = ord)
(vec...,)
end
function sort_tuple_stable((@nospecialize ord::Ordering), @nospecialize tup::Tuple)
if tup isa Tuple2OrMore
if tup isa Tuple32OrMore
sort_tuple_array_fallback(ord, tup)
else
sort_tuple_stable_2_or_more(ord, tup)
end
else
tup
end
end
end

function sort(
tup::Tuple;
lt = isless,
by = identity,
rev::Union{Nothing, Bool} = nothing,
order::Ordering = Forward,
)
o = ord(lt, by, rev, order)
_SortTupleStable.sort_tuple_stable(o, tup)
end

## partialsortperm: the permutation to sort the first k elements of an array ##

"""
Expand Down
19 changes: 14 additions & 5 deletions base/tuple.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module _TupleTypeByLength
export Tuple1OrMore, Tuple2OrMore, Tuple32OrMore
const Tuple1OrMore = Tuple{Any, Vararg}
const Tuple2OrMore = Tuple{Any, Any, Vararg}
const Tuple32OrMore = Tuple{
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Vararg{Any, N},
} where {N}
end

# Document NTuple here where we have everything needed for the doc system
"""
NTuple{N, T}
Expand Down Expand Up @@ -358,11 +371,7 @@ map(f, t::Tuple{Any, Any}) = (@inline; (f(t[1]), f(t[2])))
map(f, t::Tuple{Any, Any, Any}) = (@inline; (f(t[1]), f(t[2]), f(t[3])))
map(f, t::Tuple) = (@inline; (f(t[1]), map(f,tail(t))...))
# stop inlining after some number of arguments to avoid code blowup
const Any32{N} = Tuple{Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Vararg{Any,N}}
const Any32 = _TupleTypeByLength.Tuple32OrMore
const All32{T,N} = Tuple{T,T,T,T,T,T,T,T,
T,T,T,T,T,T,T,T,
T,T,T,T,T,T,T,T,
Expand Down
140 changes: 140 additions & 0 deletions base/typedomainnumbers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Adapted from the TypeDomainNaturalNumbers.jl package.
module _TypeDomainNumbers
module Zeros
export Zero
struct Zero end
end

module PositiveIntegers
module RecursiveStep
using ...Zeros
export recursive_step
function recursive_step(@nospecialize t::Type)
Union{Zero, t}
end
end
module UpperBounds
using ..RecursiveStep
abstract type A end
abstract type B{P <: recursive_step(A)} <: A end
abstract type C{P <: recursive_step(B)} <: B{P} end
abstract type D{P <: recursive_step(C)} <: C{P} end
end
using .RecursiveStep
const PositiveIntegerUpperBound = UpperBounds.A
const PositiveIntegerUpperBoundTighter = UpperBounds.D
export
natural_successor, natural_predecessor,
NonnegativeInteger, NonnegativeIntegerUpperBound,
PositiveInteger, PositiveIntegerUpperBound
struct PositiveInteger{
Predecessor <: recursive_step(PositiveIntegerUpperBoundTighter),
} <: PositiveIntegerUpperBoundTighter{Predecessor}
predecessor::Predecessor
global const NonnegativeInteger = recursive_step(PositiveInteger)
global const NonnegativeIntegerUpperBound = recursive_step(PositiveIntegerUpperBound)
global function natural_successor(p::P) where {P <: NonnegativeInteger}
new{P}(p)
end
end
function natural_predecessor(@nospecialize o::PositiveInteger)
getfield(o, :predecessor) # avoid specializing `getproperty` for each number
end
end

module IntegersGreaterThanOne
using ..PositiveIntegers
export
IntegerGreaterThanOne, IntegerGreaterThanOneUpperBound,
natural_predecessor_predecessor
const IntegerGreaterThanOne = let t = PositiveInteger
t{P} where {P <: t}
end
const IntegerGreaterThanOneUpperBound = let t = PositiveIntegerUpperBound
PositiveIntegers.UpperBounds.B{P} where {P <: t}
end
function natural_predecessor_predecessor(@nospecialize x::IntegerGreaterThanOne)
natural_predecessor(natural_predecessor(x))
end
end

module Constants
using ..Zeros, ..PositiveIntegers
export n0, n1
const n0 = Zero()
const n1 = natural_successor(n0)
end

module Utils
using ..PositiveIntegers, ..IntegersGreaterThanOne, ..Constants
using Base: @assume_effects
export half_floor, half_ceiling
@assume_effects :foldable :nothrow function half_floor(@nospecialize m::NonnegativeInteger)
if m isa IntegerGreaterThanOneUpperBound
let n = natural_predecessor_predecessor(m), rec = half_floor(n)
natural_successor(rec)
end
else
n0
end
end
@assume_effects :foldable :nothrow function half_ceiling(@nospecialize m::NonnegativeInteger)
if m isa IntegerGreaterThanOneUpperBound
let n = natural_predecessor_predecessor(m), rec = half_ceiling(n)
natural_successor(rec)
end
else
if m isa PositiveIntegerUpperBound
n1
else
n0
end
end
end
end
end

module _TypeDomainNumberTupleUtils
using
.._TypeDomainNumbers.PositiveIntegers, .._TypeDomainNumbers.IntegersGreaterThanOne,
.._TypeDomainNumbers.Constants, .._TypeDomainNumbers.Utils, .._TupleTypeByLength
using Base: @assume_effects, front, tail
export tuple_type_domain_length, split_tuple_into_halves, skip_from_front, skip_from_tail
@assume_effects :foldable :nothrow function tuple_type_domain_length(@nospecialize tup::Tuple)
if tup isa Tuple1OrMore
let t = tail(tup), rec = tuple_type_domain_length(t)
natural_successor(rec)
end
else
n0
end
end
@assume_effects :foldable function skip_from_front((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger)
if skip_count isa PositiveIntegerUpperBound
let cm1 = natural_predecessor(skip_count), t = tail(tup)
@inline skip_from_front(t, cm1)
end
else
tup
end
end
@assume_effects :foldable function skip_from_tail((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger)
if skip_count isa PositiveIntegerUpperBound
let cm1 = natural_predecessor(skip_count), t = front(tup)
@inline skip_from_tail(t, cm1)
end
else
tup
end
end
function split_tuple_into_halves(@nospecialize tup::Tuple)
len = tuple_type_domain_length(tup)
len_l = half_floor(len)
len_r = half_ceiling(len)
tup_l = skip_from_tail(tup, len_r)
tup_r = skip_from_front(tup, len_l)
(tup_l, tup_r)
end
end
2 changes: 1 addition & 1 deletion test/choosetests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const TESTNAMES = [
"bitarray", "copy", "math", "fastmath", "functional", "iterators",
"operators", "ordering", "path", "ccall", "parse", "loading", "gmp",
"sorting", "spawn", "backtrace", "exceptions",
"file", "read", "version", "namedtuple",
"file", "read", "version", "namedtuple", "typedomainnumbers",
"mpfr", "broadcast", "complex",
"floatapprox", "stdlib", "reflection", "regex", "float16",
"combinatorics", "sysinfo", "env", "rounding", "ranges", "mod2pi",
Expand Down
44 changes: 44 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,50 @@ end
end
@test sort(1:2000, by=x->x÷100, rev=true) == sort(1:2000, by=x->-x÷100) ==
vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99)
@testset "tuples" begin
tup = Tuple(0:9)
@test tup === sort(tup; by = _ -> 0)
@test (0, 2, 4, 6, 8, 1, 3, 5, 7, 9) === sort(tup; by = x -> isodd(x))
@test (1, 3, 5, 7, 9, 0, 2, 4, 6, 8) === sort(tup; by = x -> iseven(x))
end
end

@testset "tuple sorting" begin
max_unrolled_length = 31
@testset "correctness" begin
tup = Tuple(0:9)
tup_rev = reverse(tup)
@test tup === @inferred sort(tup)
@test tup === sort(tup; rev = false)
@test tup_rev === sort(tup; rev = true)
@test tup_rev === sort(tup; lt = >)
end
@testset "inference" begin
known_length = (Tuple{Vararg{Int, max_unrolled_length}}, Tuple{Vararg{Float64, max_unrolled_length}})
unknown_length = (Tuple{Vararg{Int}}, Tuple{Vararg{Float64}})
for Tup (known_length..., unknown_length...)
@test Tup == Base.infer_return_type(sort, Tuple{Tup})
end
for Tup (known_length...,)
@test Core.Compiler.is_foldable(Base.infer_effects(sort, Tuple{Tup}))
end
end
@testset "alloc" begin
function test_zero_allocated(tup::Tuple)
@test iszero(@allocated sort(tup))
end
test_zero_allocated(ntuple(identity, max_unrolled_length))
end
@testset "heterogeneous" begin
@testset "stability" begin
tup = (0, 0x0, 0x000)
@test tup === sort(tup)
end
tup = (1, 2, 3, missing, missing)
for t (tup, (1, missing, 2, missing, 3), (missing, missing, 1, 2, 3))
@test tup === @inferred sort(t)
end
end
end

@testset "partialsort" begin
Expand Down
31 changes: 31 additions & 0 deletions test/typedomainnumbers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

using
Test,
Base._TypeDomainNumbers.PositiveIntegers,
Base._TypeDomainNumbers.IntegersGreaterThanOne,
Base._TypeDomainNumbers.Constants,
Base._TypeDomainNumberTupleUtils

@testset "type domain numbers" begin
@test n0 isa NonnegativeInteger
@test n1 isa NonnegativeInteger
@test n1 isa PositiveInteger
@testset "succ" begin
for x (n0, n1)
@test x === natural_predecessor(@inferred natural_successor(x))
@test x === natural_predecessor_predecessor(natural_successor(natural_successor(x)))
end
end
@testset "type safety" begin
@test_throws TypeError PositiveInteger{Int}
end
@testset "tuple utils" begin
@test n0 === @inferred tuple_type_domain_length(())
@test n1 === @inferred tuple_type_domain_length((7,))
@test ((), ()) === @inferred split_tuple_into_halves(())
@test ((), (7,)) === @inferred split_tuple_into_halves((7,))
@test ((3,), (7,)) === @inferred split_tuple_into_halves((3, 7))
@test ((3,), (7, 9)) === @inferred split_tuple_into_halves((3, 7, 9))
end
end

0 comments on commit 59a5c2d

Please sign in to comment.