Skip to content

Commit

Permalink
Rules for sortslices, unique (#546)
Browse files Browse the repository at this point in the history
* unique, take 1

* add shortcut, and tests

* sortslices, too

* fixup

* Apply 3 suggestions

Co-authored-by: Lyndon White <[email protected]>

* comment

Co-authored-by: Lyndon White <[email protected]>

Co-authored-by: Lyndon White <[email protected]>
  • Loading branch information
mcabbott and oxinabox authored Nov 30, 2021
1 parent 605354c commit ce78d3d
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/rulesets/Base/sort.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#####
##### `sort`
#####

function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...)
inds = partialsortperm(xs, k; kwargs...)
ys = xs[inds]
Expand Down Expand Up @@ -33,3 +37,55 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
end
return ys, sort_pullback
end

#####
##### `sortslices`
#####

# Reverse-mode rule for `sortslices` along one integer dimension `dims`.
# The forward pass is written as an indexing operation with the sorting permutation,
# so that the pullback can route the incoming tangent through the same indices.
# All keywords other than `dims` are forwarded to `sortperm`.
function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
    # Order in which the slices along `dims` must be taken to be sorted.
    perm = sortperm(collect(eachslice(x; dims=dims)); kw...)
    # One index per dimension: the permutation along `dims`, a full `:` everywhere else.
    selector = ntuple(d -> d == dims ? perm : (:), ndims(x))
    sorted = x[selector...]
    function sortslices_pullback(dy)
        # NOTE(review): `_zerolike_writeat` is defined elsewhere; presumably it writes
        # `dy` at `selector` into a zero array shaped like `x` -- confirm.
        # Zeroing is not actually required here, and skipping it would allow a wider
        # eltype. `similar(dy)` might also do in place of `x`, being the same size?
        dx = _zerolike_writeat(x, unthunk(dy), (), selector...)
        return (NoTangent(), ProjectTo(x)(dx))
    end
    return sorted, sortslices_pullback
end

#####
##### `unique`
#####

# Reverse-mode rule for `unique` on numeric arrays, for `dims=:` (the default) or an
# integer `dims`. Returns `unique(x; dims)` together with a pullback which sends each
# entry (or slice) of the output tangent back to the position of its first occurrence
# in `x`.
function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:)
    x_axes = axes(x)
    y = unique(x; dims=dims)  # `unique` accepts only dims=: or dims::Integer
    function unique_pullback(dy_lazy)
        dy = unthunk(dy_lazy)

        # Fast path: nothing was removed, so the tangent only needs the shape of `x`.
        # This avoids building the comparison mask below, which is fairly expensive.
        if length(y) == length(x)
            return (NoTangent(), ProjectTo(x)(reshape(dy, x_axes)))
        end

        # Things to compare: individual elements for dims=:, whole slices otherwise.
        if dims isa Colon
            inputs = vec(x)
            outputs = y
        else
            inputs = collect(eachslice(x; dims=dims))
            outputs = collect(eachslice(y; dims=dims))
        end

        # hits[i, j] says whether input i equals output j. `isequal` is used, not `==`,
        # so that cases like unique([0.0, -0.0, NaN, NaN]) are matched correctly.
        hits = isequal.(permutedims(outputs), inputs)
        # Keep only the first hit in every column, i.e. findfirst(hits; dims=1).
        first_hit = hits .& (cumsum(hits; dims=1) .== 1)
        # Row number of each surviving hit: the position in `inputs` that was kept.
        keep = getindex.(findall(first_hit), 1)

        # The function `_zerolike_writeat` allows second derivatives.
        # Should perhaps eventually be shared with `getindex`.
        dx = if dims isa Colon
            reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), x_axes)
        else
            selector = ntuple(d -> d == dims ? keep : (:), length(x_axes))
            _zerolike_writeat(x, dy, (), selector...)
        end
        return (NoTangent(), ProjectTo(x)(dx))
    end
    return y, unique_pullback
end
29 changes: 29 additions & 0 deletions test/rulesets/Base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,33 @@

test_rrule(partialsort, a, 4, fkwargs=(;rev=true))
end

# Checks the `sortslices` rule with `test_rrule` on random inputs.
@testset "sortslices" begin
# Matrix input, slices taken along the second dimension, default ordering.
test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2))
# Slices along the first dimension, with the `rev` and `by` keywords passed as well.
test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last))
# Three-dimensional input; the inference check is switched off for this case.
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false)

# Plain `sortslices` (the function, not the rule) is expected to throw on a `Diagonal`.
# NOTE(review): presumably recorded here to show the rule need not handle it -- confirm.
@test_throws Exception sortslices(Diagonal(1:3), dims=1)
end

# Checks the `unique` rule: `test_rrule` on random inputs, then exact expected values
# for inputs containing repeats.
@testset "unique" begin
# Trivial case, all unique (entries drawn by `rand` are almost surely distinct):
test_rrule(unique, rand(5))
test_rrule(unique, rand(3,4))
test_rrule(unique, rand(3,4); fkwargs=(; dims=2))

# Not all unique: the tangent goes to the first occurrence, repeats receive zero.
@test rrule(unique, [1,1,2,3])[1] == [1,2,3]
@test rrule(unique, [1,1,2,3])[2]([10,20,30]) == (NoTangent(), [10, 0, 20, 30])

# Matrix input with the default `dims=:`; the output is a vector, the gradient a matrix.
@test rrule(unique, [1 2; 1 4])[1] == [1,2,4]
@test rrule(unique, [1 2; 1 4])[2]([10,20,30]) == (NoTangent(), [10 20; 0 30])

# With `dims=2` whole columns are compared: the third column repeats the first.
@test rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[1] == [1 2 2; 1 2 4]
@test rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[2]([10 20 30; 40 50 60])[2] == [10 20 0 30; 40 50 0 60]

# Structured input: the off-diagonal zeros of a `Diagonal` count as one repeated value.
# The pullback is given a 1×4 row here; the 20 paired with that zero is absent from the
# expected gradient, which must itself be a `Diagonal`.
@test rrule(unique, Diagonal([1,2,3]))[1] == [1,0,2,3]
@test rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] == [10.0 0.0 0.0; 0.0 30.0 0.0; 0.0 0.0 40.0]
@test rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] isa Diagonal
end
end

0 comments on commit ce78d3d

Please sign in to comment.