Patch summary (text before the first "diff --git" line is ignored by patch tools):
  * src/rulesets/Base/sort.jl: add section banners and new `rrule` methods for
    `sortslices` (keyword `dims::Integer`) and `unique` (`AbstractArray{<:Number}`,
    `dims=:` or an integer). Both pullbacks scatter the incoming cotangent back
    into an array shaped like the input via `_zerolike_writeat`, then apply
    `ProjectTo(x)`.
  * test/rulesets/Base/sort.jl: add "sortslices" and "unique" testsets.
Note: line breaks were restored from a collapsed copy; indentation follows the
usual 4-space Julia convention.

diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl
index cf3eadc4b..be7840c8c 100644
--- a/src/rulesets/Base/sort.jl
+++ b/src/rulesets/Base/sort.jl
@@ -1,3 +1,7 @@
+#####
+##### `sort`
+#####
+
 function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...)
     inds = partialsortperm(xs, k; kwargs...)
     ys = xs[inds]
@@ -33,3 +37,55 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
     end
     return ys, sort_pullback
 end
+
+#####
+##### `sortslices`
+#####
+
+function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
+    p = sortperm(collect(eachslice(x; dims=dims)); kw...)
+    inds = ntuple(d -> d == dims ? p : (:), ndims(x))
+    function sortslices_pullback(dy)
+        # No actual need to zero this, and if you didn't, then you could widen eltype
+        # Also, you could use similar(dy) here not x, same size?
+        dx = _zerolike_writeat(x, unthunk(dy), (), inds...)
+        return (NoTangent(), ProjectTo(x)(dx))
+    end
+    return x[inds...], sortslices_pullback
+end
+
+#####
+##### `unique`
+#####
+
+function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:)
+    axes_x = axes(x)
+    y = unique(x; dims=dims) # accepts only dims=: or dims::Integer
+    function unique_pullback(dy_raw)
+        dy = unthunk(dy_raw)
+        if length(x) == length(y)
+            # Short-circuit for the case of all unique, since `mask` is fairly expensive:
+            dx = reshape(dy, axes_x)
+            return (NoTangent(), ProjectTo(x)(dx))
+        end
+
+        if dims isa Colon
+            xs, ys = vec(x), y
+        else
+            xs, ys = collect(eachslice(x; dims=dims)), collect(eachslice(y; dims=dims))
+        end
+        mask = isequal.(permutedims(ys), xs) # unique([0.0, -0.0, NaN, NaN])
+        mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1)
+        keep = map(I -> I[1], findall(mask))
+        if dims isa Colon
+            # The function `_zerolike_writeat` allows second derivatives.
+            # Should perhaps eventually be shared with `getindex`.
+            dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x)
+        else
+            inds = ntuple(d -> d==dims ? keep : (:), length(axes_x))
+            dx = _zerolike_writeat(x, dy, (), inds...)
+        end
+        return (NoTangent(), ProjectTo(x)(dx))
+    end
+    return y, unique_pullback
+end
diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl
index 5f3bc4213..f76109586 100644
--- a/test/rulesets/Base/sort.jl
+++ b/test/rulesets/Base/sort.jl
@@ -12,4 +12,33 @@
 
         test_rrule(partialsort, a, 4, fkwargs=(;rev=true))
     end
+
+    @testset "sortslices" begin
+        test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2))
+        test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last))
+        test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false)
+
+        @test_throws Exception sortslices(Diagonal(1:3), dims=1)
+    end
+
+    @testset "unique" begin
+        # Trivial case, all unique:
+        test_rrule(unique, rand(5))
+        test_rrule(unique, rand(3,4))
+        test_rrule(unique, rand(3,4); fkwargs=(; dims=2))
+
+        # Not all unique:
+        @test rrule(unique, [1,1,2,3])[1] == [1,2,3]
+        @test rrule(unique, [1,1,2,3])[2]([10,20,30]) == (NoTangent(), [10, 0, 20, 30])
+
+        @test rrule(unique, [1 2; 1 4])[1] == [1,2,4]
+        @test rrule(unique, [1 2; 1 4])[2]([10,20,30]) == (NoTangent(), [10 20; 0 30])
+
+        @test rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[1] == [1 2 2; 1 2 4]
+        @test rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[2]([10 20 30; 40 50 60])[2] == [10 20 0 30; 40 50 0 60]
+
+        @test rrule(unique, Diagonal([1,2,3]))[1] == [1,0,2,3]
+        @test rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] == [10.0 0.0 0.0; 0.0 30.0 0.0; 0.0 0.0 40.0]
+        @test rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] isa Diagonal
+    end
 end