From 6b10e706fb0a4d1477891da095df22fc3730cbe0 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Sat, 2 Nov 2024 21:41:04 +0100 Subject: [PATCH] support sorting tuples Uses merge sort, as an obvious choice for a stable sort of tuples. A recursive data structure of singleton type, representing Peano natural numbers, is used to help with splitting a tuple into two halves in the merge sort. An alternative design would use a reference tuple, but this would require relying on `tail`, which seems more harsh on the compiler. With the recursive datastructure the predecessor operation and the successor operation are both trivial. Allows inference to preserve inferred element type even when tuple length is not known. Follow-up PRs may add further improvements, such as the ability to select an unstable sorting algorithm. The added file, typedomainnumbers.jl is not specific to sorting, thus making it a separate file. Xref #55571. Fixes #54489 --- NEWS.md | 1 + base/Base.jl | 2 + base/sort.jl | 89 +++++++++++++++++ base/typedomainnumbers.jl | 194 ++++++++++++++++++++++++++++++++++++++ test/sorting.jl | 34 +++++++ 5 files changed, 320 insertions(+) create mode 100644 base/typedomainnumbers.jl diff --git a/NEWS.md b/NEWS.md index ba9ca1c521c55b..b47d974cad4424 100644 --- a/NEWS.md +++ b/NEWS.md @@ -119,6 +119,7 @@ New library features * `Base.require_one_based_indexing` and `Base.has_offset_axes` are now public ([#56196]) * New `ltruncate`, `rtruncate` and `ctruncate` functions for truncating strings to text width, accounting for char widths ([#55351]) * `isless` (and thus `cmp`, sorting, etc.) is now supported for zero-dimensional `AbstractArray`s ([#55772]) +* `sort` now sorts tuples (#5XXXX) Standard library changes ------------------------ diff --git a/base/Base.jl b/base/Base.jl index 3b56dca166cee1..8315c01a6a7dad 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -390,6 +390,8 @@ include("cartesian.jl") using .Cartesian include("multidimensional.jl") +include("typedomainnumbers.jl") + include("broadcast.jl") using .Broadcast using .Broadcast: broadcasted, broadcasted_kwsyntax, materialize, materialize!, diff --git a/base/sort.jl b/base/sort.jl index ef0f208209fc8d..fed295f72a7c01 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -1736,6 +1736,95 @@ julia> v """ sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...) +module _SortTupleStable + using + Base._TypeDomainNumbers.PositiveIntegers, Base._TypeDomainNumbers.IntegersGreaterThanOne, + Base._TypeDomainNumbers.Utils, Base._TypeDomainNumberTupleUtils, Base._TupleTypeByLength + using Base: tail + using Base.Order: Ordering, lt + export sort_tuple_stable + function merge_recursive((@nospecialize ord::Ordering), a::Tuple, b::Tuple) + ret = if a isa Tuple1OrMore + a + else + b + end + type_assert_tuple_0_or_more(ret) + end + function merge_recursive(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore) + l = first(a) + r = first(b) + x = tail(a) + y = tail(b) + merged = if lt(ord, r, l) + let rec = type_assert_tuple_1_or_more(merge_recursive(ord, a, y)) + (r, rec...) + end + else + let rec = type_assert_tuple_1_or_more(merge_recursive(ord, x, b)) + (l, rec...) + end + end + type_assert_tuple_2_or_more(merged) + end + function merge_nontrivial(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore) + ret = merge_recursive(ord, a, b) + type_assert_tuple_2_or_more(ret) + end + function split_tuple(@nospecialize tup::Tuple2OrMore) + len = type_assert_integer_greater_than_1(tuple_type_domain_length(tup)) + len_l = type_assert_positive_integer(half_floor_nontrivial(len)) + len_r = type_assert_positive_integer(half_ceiling_nontrivial(len)) + tup_l = type_assert_tuple_1_or_more(skip_from_tail_nontrivial(tup, len_r)) + tup_r = type_assert_tuple_1_or_more(skip_from_front_nontrivial(tup, len_l)) + (tup_l, tup_r) + end + function sort_recursive((@nospecialize ord::Ordering), @nospecialize tup::Tuple{Any}) + tup + end + function sort_recursive(ord::Ordering, tup::Tuple2OrMore) + (tup_l, tup_r) = split_tuple(tup) + sorted_l = type_assert_tuple_1_or_more(sort_recursive(ord, tup_l)) + sorted_r = type_assert_tuple_1_or_more(sort_recursive(ord, tup_r)) + type_assert_tuple_2_or_more(merge_nontrivial(ord, sorted_l, sorted_r)) + end + function sort_tuple_stable_2_or_more(ord::Ordering, tup::Tuple2OrMore) + ret = sort_recursive(ord, tup) + type_assert_tuple_2_or_more(ret) + end + function sort_tuple_array_fallback(ord::Ordering, tup::Tuple2OrMore) + vec = if tup isa NTuple + [tup...] + else + Any[tup...] + end + sort!(vec; order = ord) + (vec...,) + end + function sort_tuple_stable((@nospecialize ord::Ordering), @nospecialize tup::Tuple) + if tup isa Tuple2OrMore + if tup isa Tuple32OrMore + sort_tuple_array_fallback(ord, tup) + else + sort_tuple_stable_2_or_more(ord, tup) + end + else + tup + end + end +end + +function sort( + tup::Tuple; + lt = isless, + by = identity, + rev::Union{Nothing, Bool} = nothing, + order::Ordering = Forward, +) + o = ord(lt, by, rev, order) + _SortTupleStable.sort_tuple_stable(o, tup) +end + ## partialsortperm: the permutation to sort the first k elements of an array ## """ diff --git a/base/typedomainnumbers.jl b/base/typedomainnumbers.jl new file mode 100644 index 00000000000000..acd97d7dcbd636 --- /dev/null +++ b/base/typedomainnumbers.jl @@ -0,0 +1,194 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +# Adapted from the TypeDomainNaturalNumbers.jl package. +module _TypeDomainNumbers + module Zeros + export Zero + struct Zero end + end + + module PositiveIntegers + module RecursiveStep + using ...Zeros + export recursive_step + function recursive_step(@nospecialize t::Type) + Union{Zero, t} + end + end + module UpperBounds + using ..RecursiveStep + abstract type A end + abstract type B{P <: recursive_step(A)} <: A end + abstract type C{P <: recursive_step(B)} <: B{P} end + abstract type D{P <: recursive_step(C)} <: C{P} end + end + using .RecursiveStep + const PositiveIntegerUpperBound = UpperBounds.A + const PositiveIntegerUpperBoundTighter = UpperBounds.D + export + natural_successor, natural_predecessor, + NonnegativeInteger, NonnegativeIntegerUpperBound, + PositiveInteger, PositiveIntegerUpperBound, + type_assert_nonnegative_integer, type_assert_positive_integer + struct PositiveInteger{ + Predecessor <: recursive_step(PositiveIntegerUpperBoundTighter), + } <: PositiveIntegerUpperBoundTighter{Predecessor} + predecessor::Predecessor + global const NonnegativeInteger = recursive_step(PositiveInteger) + global const NonnegativeIntegerUpperBound = recursive_step(PositiveIntegerUpperBound) + global function natural_successor(p::P) where {P <: NonnegativeInteger} + ret = new{P}(p) + type_assert_positive_integer(ret) + end + end + function type_assert_nonnegative_integer(@nospecialize x::NonnegativeInteger) + x + end + function type_assert_positive_integer(@nospecialize x::PositiveInteger) + x + end + function natural_predecessor(@nospecialize o::PositiveInteger) + ret = getfield(o, :predecessor) # avoid specializing `getproperty` for each number + type_assert_nonnegative_integer(ret) + end + end + + module IntegersGreaterThanOne + using ..PositiveIntegers + export + IntegerGreaterThanOne, IntegerGreaterThanOneUpperBound, + type_assert_integer_greater_than_1 + const IntegerGreaterThanOne = let t = PositiveInteger + t{P} where {P <: t} + end + const IntegerGreaterThanOneUpperBound = let t = PositiveIntegerUpperBound + PositiveIntegers.UpperBounds.B{P} where {P <: t} + end + function type_assert_integer_greater_than_1(@nospecialize x::IntegerGreaterThanOne) + x + end + end + + module Constants + using ..Zeros, ..PositiveIntegers + export n0, n1 + const n0 = Zero() + const n1 = natural_successor(n0) + end + + module Utils + using ..PositiveIntegers, ..IntegersGreaterThanOne, ..Constants + using Base: @assume_effects + export minus_two, half_floor, half_ceiling, half_floor_nontrivial, half_ceiling_nontrivial + function minus_two(@nospecialize m::IntegerGreaterThanOne) + natural_predecessor(natural_predecessor(m)) + end + @assume_effects :foldable :nothrow function half_floor(@nospecialize m::NonnegativeInteger) + ret = if m isa IntegerGreaterThanOneUpperBound + let n = minus_two(m), rec = @inline half_floor(n) + type_assert_positive_integer(natural_successor(rec)) + end + else + n0 + end + type_assert_nonnegative_integer(ret) + end + @assume_effects :foldable :nothrow function half_ceiling(@nospecialize m::NonnegativeInteger) + ret = if m isa IntegerGreaterThanOneUpperBound + let n = minus_two(m), rec = @inline half_ceiling(n) + type_assert_positive_integer(natural_successor(rec)) + end + else + if m isa PositiveIntegerUpperBound + n1 + else + n0 + end + end + type_assert_nonnegative_integer(ret) + end + function half_floor_nontrivial(@nospecialize m::IntegerGreaterThanOne) + ret = half_floor(m) + type_assert_positive_integer(ret) + end + function half_ceiling_nontrivial(@nospecialize m::IntegerGreaterThanOne) + ret = half_ceiling(m) + type_assert_positive_integer(ret) + end + end +end + +module _TupleTypeByLength + export + Tuple1OrMore, Tuple2OrMore, Tuple3OrMore, Tuple4OrMore, Tuple32OrMore, + type_assert_tuple_0_or_more, type_assert_tuple_1_or_more, type_assert_tuple_2_or_more, + type_assert_tuple_3_or_more, type_assert_tuple_4_or_more, + type_assert_tuple_1 + const Tuple1OrMore = Tuple{Any, Vararg} + const Tuple2OrMore = Tuple{Any, Any, Vararg} + const Tuple3OrMore = Tuple{Any, Any, Any, Vararg} + const Tuple4OrMore = Tuple{Any, Any, Any, Any, Vararg} + const Tuple32OrMore = Base.Any32 + function type_assert_tuple_0_or_more(@nospecialize x::Tuple) + x + end + function type_assert_tuple_1_or_more(@nospecialize x::Tuple1OrMore) + x + end + function type_assert_tuple_2_or_more(@nospecialize x::Tuple2OrMore) + x + end + function type_assert_tuple_3_or_more(@nospecialize x::Tuple3OrMore) + x + end + function type_assert_tuple_4_or_more(@nospecialize x::Tuple4OrMore) + x + end +end + +module _TypeDomainNumberTupleUtils + using + .._TypeDomainNumbers.PositiveIntegers, .._TypeDomainNumbers.IntegersGreaterThanOne, + .._TypeDomainNumbers.Constants, .._TupleTypeByLength + using Base: @assume_effects, front, tail + export + tuple_type_domain_length, + skip_from_front, skip_from_tail, + skip_from_front_nontrivial, skip_from_tail_nontrivial + # The `@nospecialize` and `@inline` together should effectively result in specializing + # on the length, without specializing on the types of the elements. + @assume_effects :foldable :nothrow function tuple_type_domain_length(@nospecialize tup::Tuple) + ret = if tup isa Tuple1OrMore + let t = tail(tup), rec = @inline tuple_type_domain_length(t) + type_assert_positive_integer(natural_successor(rec)) + end + else + n0 + end + type_assert_nonnegative_integer(ret) + end + @assume_effects :foldable function skip_from_front((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = tail(tup) + @inline skip_from_front(t, cm1) + end + else + tup + end + end + @assume_effects :foldable function skip_from_tail((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = front(tup) + @inline skip_from_tail(t, cm1) + end + else + tup + end + end + function skip_from_front_nontrivial((@nospecialize tup::Tuple2OrMore), @nospecialize skip_count::PositiveInteger) + skip_from_front(tup, skip_count) + end + function skip_from_tail_nontrivial((@nospecialize tup::Tuple2OrMore), @nospecialize skip_count::PositiveInteger) + skip_from_tail(tup, skip_count) + end +end diff --git a/test/sorting.jl b/test/sorting.jl index 2714197f58823a..4118391abf5d8d 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -92,6 +92,40 @@ end end @test sort(1:2000, by=x->x÷100, rev=true) == sort(1:2000, by=x->-x÷100) == vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99) + @testset "tuples" begin + tup = Tuple(0:9) + @test tup === sort(tup; by = _ -> 0) + @test (0, 2, 4, 6, 8, 1, 3, 5, 7, 9) === sort(tup; by = x -> isodd(x)) + @test (1, 3, 5, 7, 9, 0, 2, 4, 6, 8) === sort(tup; by = x -> iseven(x)) + end +end + +@testset "tuple sorting" begin + max_unrolled_length = 31 + @testset "correctness" begin + tup = Tuple(0:9) + tup_rev = reverse(tup) + @test tup === @inferred sort(tup) + @test tup === sort(tup; rev = false) + @test tup_rev === sort(tup; rev = true) + @test tup_rev === sort(tup; lt = >) + end + @testset "inference" begin + known_length = (Tuple{Vararg{Int, max_unrolled_length}}, Tuple{Vararg{Float64, max_unrolled_length}}) + unknown_length = (Tuple{Vararg{Int}}, Tuple{Vararg{Float64}}) + for Tup ∈ (known_length..., unknown_length...) + @test Tup == Base.infer_return_type(sort, Tuple{Tup}) + end + for Tup ∈ (known_length...,) + @test Core.Compiler.is_foldable(Base.infer_effects(sort, Tuple{Tup})) + end + end + @testset "alloc" begin + function test_zero_allocated(tup::Tuple) + @test iszero(@allocated sort(tup)) + end + test_zero_allocated(ntuple(identity, max_unrolled_length)) + end end @testset "partialsort" begin