Skip to content

Commit

Permalink
Add rrules for extrema, findmax, maximum (#480)
Browse files Browse the repository at this point in the history
* rules for extrema, findmax, maximum

* fixup extrema

* symmetric maximum rule

* promote types by hand

* argmax?

* allow more zeros

* upgrade tests

* don't do symmetric convention

* tests

* fix 1.0

* rm symmetric versions

* move extrema to last

* tidy

* fixup extrema

* tests

* tests

* use eval loop, tidy, tests

* forward rules for maximum

* frules for findmax

* tidy

* widen similar to ensure writeability

* comments

* dispatch -> branch

* allow for second derivatives

* frule?

* update to use CRC 1.3

* better writezero?

* fix tests

* allow arrays of arrays

* version
  • Loading branch information
mcabbott authored Nov 24, 2021
1 parent a751937 commit 605354c
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 3 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.13.0"
version = "1.14.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,10 +11,10 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1.10"
ChainRulesCore = "1.11"
ChainRulesTestUtils = "1"
Compat = "3.35"
FiniteDifferences = "0.12.8"
FiniteDifferences = "0.12.20"
JuliaInterpreter = "0.8"
RealDot = "0.1"
StaticArrays = "1.2"
Expand Down
162 changes: 162 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,165 @@ function rrule(::typeof(fill), x::Any, dims...)
fill_pullback(Ȳ) = (NoTangent(), project(sum(Ȳ)), nots...)
return fill(x, dims...), fill_pullback
end

#####
##### `findmax`, `maximum`, etc.
#####

for findm in (:findmin, :findmax)
    findm_pullback = Symbol(findm, :_pullback)

    @eval begin
        # Forward rule: the tangent of the selected value is the entry of `xdot` at the
        # index chosen by the primal call; the index itself gets `NoTangent()`.
        function frule((_, xdot), ::typeof($findm), x; dims=:)
            val, idx = $findm(x; dims=dims)
            primal = (val, idx)
            return primal, Tangent{typeof(primal)}(xdot[idx], NoTangent())
        end

        # Reverse rule: the cotangent of the value is written back at the selected index
        # of an otherwise-zero array shaped like `x`.
        function rrule(::typeof($findm), x::AbstractArray; dims=:)
            val, idx = $findm(x; dims=dims)
            to_x = ProjectTo(x)
            # This pullback is a lot like the one for getindex; ideally the two would
            # probably be combined. It accepts e.g.
            # Tangent{Tuple{Float64, Int64}}(4.0, nothing); the index cotangent is ignored.
            function $findm_pullback((dy, _))
                if dy isa AbstractZero
                    return (NoTangent(), NoTangent())
                end
                x_thunk = @thunk to_x(_zerolike_writeat(x, unthunk(dy), dims, idx))
                x_ithunk = InplaceableThunk(x_thunk) do dx
                    if dims isa Colon
                        # `Ref` keeps `dy` from being broadcast over when it is itself an array.
                        view(dx, idx) .= view(dx, idx) .+ Ref(unthunk(dy))
                    else
                        # This could be `.+=`, but not on Julia 1.0.
                        view(dx, idx) .= view(dx, idx) .+ unthunk(dy)
                    end
                    dx
                end
                return (NoTangent(), x_ithunk)
            end
            return (val, idx), $findm_pullback
        end
    end
end

# This function is roughly `setindex!(zero(x), dy, inds...)`:

function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...)
    # `x` is only used as a template for the output's container type and axes.
    # Closing over it is unfortunate, but `similar(typeof(x), axes(x))` doesn't allow
    # `eltype(dy)`, nor does it work for many structured matrices.
    out = similar(x, eltype(dy), axes(x))
    fill!(out, 0)
    # Possibly a 0-dim view; this allows dy::Number and dy::Array, and a CuArray output.
    target = view(out, inds...)
    target .= dy
    return out
end
function _zerolike_writeat(x::AbstractArray, dy, dims, inds...)
    # Having `x` itself lets us build element-wise zeros, so arrays of arrays work too.
    out = map(zero, x)
    slot = view(out, inds...)
    if dims isa Colon
        # Without `dims`, `dy` is one element; `Ref` stops broadcasting from iterating it.
        slot .= Ref(dy)
    else
        slot .= dy
    end
    return out
end

# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
# these rules are the reason it takes a `dims` argument.

function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
    # Only the third tangent (that of `dy`) is used: the tangent of the output is
    # `dydot` written into zeros at the same indices as the primal.
    primal = _zerolike_writeat(x, dy, dims, inds...)
    tangent = _zerolike_writeat(x, dydot, dims, inds...)
    return primal, tangent
end

function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
    z = _zerolike_writeat(x, dy, dims, inds...)
    function _zerolike_writeat_pullback(dz)
        # Read the cotangent back at the written indices, reduced over `dims`;
        # this is the cotangent for `dy` (third slot of the returned tuple).
        written = view(unthunk(dz), inds...)
        dy_bar = sum(written; dims=dims)
        # One `NoTangent()` per index argument.
        ind_tangents = map(_ -> NoTangent(), inds)
        return (NoTangent(), NoTangent(), dy_bar, NoTangent(), ind_tangents...)
    end
    return z, _zerolike_writeat_pullback
end

# These rules for `maximum` pick the same subgradient as `findmax`:

function frule((_, xdot), ::typeof(maximum), x; dims=:)
    # Use `findmax` for the primal so the tangent is read at the index it selects.
    found = findmax(x; dims=dims)
    val, idx = found
    tangent = getindex(xdot, idx)
    return (val, tangent)
end

function rrule(::typeof(maximum), x::AbstractArray; dims=:)
    # Delegate to the `findmax` rule and drop the index from the primal output.
    primal, findmax_back = rrule(findmax, x; dims=dims)
    function maximum_pullback(dy)
        # The index component of the `findmax` output gets no cotangent.
        return findmax_back((dy, nothing))
    end
    return first(primal), maximum_pullback
end

function frule((_, xdot), ::typeof(minimum), x; dims=:)
    # Use `findmin` for the primal so the tangent is read at the index it selects.
    found = findmin(x; dims=dims)
    val, idx = found
    tangent = getindex(xdot, idx)
    return (val, tangent)
end

function rrule(::typeof(minimum), x::AbstractArray; dims=:)
    # Delegate to the `findmin` rule and drop the index from the primal output.
    primal, findmin_back = rrule(findmin, x; dims=dims)
    function minimum_pullback(dy)
        # The index component of the `findmin` output gets no cotangent.
        return findmin_back((dy, nothing))
    end
    return first(primal), minimum_pullback
end

#####
##### `extrema`
#####

function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:)
    # Branch (rather than dispatch) on whether a `dims` keyword was supplied.
    return dims isa Colon ? _extrema_colon(x) : _extrema_dims(x, dims)
end

function _extrema_colon(x)
    lo, lo_at = findmin(x)
    hi, hi_at = findmax(x)
    to_x = ProjectTo(x)
    function extrema_pullback((dylo, dyhi))  # accepts Tangent
        # Nothing to propagate when both cotangents are zero.
        if dylo isa AbstractZero && dyhi isa AbstractZero
            return (NoTangent(), NoTangent())
        end
        # One argument may still be AbstractZero here. Use promote_op because
        # promote_type allows for * as well as +, hence gives Any.
        T = Base.promote_op(+, typeof(dylo), typeof(dyhi))
        # Computed eagerly: wrapping this in `@thunk` did not infer.
        grad = similar(x, T, axes(x))
        fill!(grad, false)
        view(grad, lo_at) .= dylo
        # Accumulate rather than overwrite, in case both extrema share one index.
        view(grad, hi_at) .= view(grad, hi_at) .+ dyhi
        return (NoTangent(), to_x(grad))
    end
    return (lo, hi), extrema_pullback
end

# Reverse rule for `extrema(x; dims=dims)`: returns an array of `(min, max)` tuples and a
# pullback which writes each tuple's two cotangents back at the indices found by
# `findmin` / `findmax`.
function _extrema_dims(x, dims)
    ylo, ilo = findmin(x; dims=dims)
    yhi, ihi = findmax(x; dims=dims)
    y = similar(ylo, Tuple{eltype(ylo), eltype(yhi)})
    map!(tuple, y, ylo, yhi)  # this is a GPU-friendly version of collect(zip(ylo, yhi))
    project = ProjectTo(x)
    function extrema_pullback_dims(dy_raw)
        dy = unthunk(dy_raw)
        # A zero cotangent for the whole output has nothing to write back. Return early,
        # as the other pullbacks in this file do, instead of failing the assertion below.
        dy isa AbstractZero && return (NoTangent(), NoTangent())
        @assert dy isa AbstractArray{<:Tuple{Any,Any}}
        # Can we actually get Array{Tuple{Float64,ZeroTangent}} here? Not sure.
        T = Base.promote_op(+, eltype(dy).parameters...)
        # Computed eagerly: wrapping this in `@thunk` did not infer.
        x_nothunk = let
            dx = fill!(similar(x, T, axes(x)), false)
            view(dx, ilo) .= first.(dy)
            # Accumulate rather than overwrite, in case min and max share an index.
            view(dx, ihi) .= view(dx, ihi) .+ last.(dy)
            project(dx)
        end
        return (NoTangent(), x_nothunk)
    end
    return y, extrema_pullback_dims
end
2 changes: 2 additions & 0 deletions src/rulesets/Base/nondiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@
@non_differentiable all(::Any, ::Any)
@non_differentiable any(::Any)
@non_differentiable any(::Any, ::Any)
@non_differentiable argmax(::Any)
@non_differentiable argmin(::Any)
@non_differentiable ascii(::AbstractString)
@non_differentiable axes(::Any)
@non_differentiable axes(::Any, ::Any)
Expand Down
76 changes: 76 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,79 @@ end
test_rrule(fill, 55 + 0.5im, 5)
test_rrule(fill, 3.3, (3, 3, 3))
end

@testset "findmin & findmax" begin
    # Forward
    test_frule(findmin, rand(10))
    test_frule(findmax, rand(10))
    @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4))) isa Tuple{Tuple{Float64, CartesianIndex}, Tangent}
    @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4), dims=1)) isa Tuple{Tuple{Matrix, Matrix}, Tangent}
    @test_skip test_frule(findmin, rand(3,4)) # error from test_approx(actual::CartesianIndex{2}, expected::CartesianIndex{2}
    @test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent()))
    @test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,))
    # These skipped tests might be fixed by https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188

    # Reverse
    test_rrule(findmin, rand(10), output_tangent = (rand(), false))
    test_rrule(findmax, rand(10), output_tangent = (rand(), false))
    test_rrule(findmin, rand(5,3))
    test_rrule(findmax, rand(5,3))
    @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2])
    @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2])

    # Reverse with dims:
    @test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2])
    @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2])
    test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()))
    test_rrule(findmin, rand(3,4), fkwargs=(dims=2,))

    # Second derivatives
    test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
    test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
    # `⊢` pairs the index matrix with an explicit `NoTangent()`; without the operator the
    # two expressions are simply adjacent inside the call, which does not parse.
    @test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9)
    y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)])
    @test y == [0 0; 5 5]
    @test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent())
end

@testset "$imum" for imum in [maximum, minimum]
    # Forward
    test_frule(imum, rand(10))
    test_frule(imum, rand(3,4))
    test_frule(imum, rand(3,4), fkwargs=(dims=1,))
    test_frule(imum, [rand(2) for _ in 1:3])
    test_frule(imum, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(dims=1,))

    # Reverse
    test_rrule(imum, rand(10))
    test_rrule(imum, rand(3,4))
    test_rrule(imum, rand(3,4), fkwargs=(dims=1,))
    test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),))

    # Arrays of arrays
    test_rrule(imum, [rand(2) for _ in 1:3]; check_inferred=false)
    test_rrule(imum, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(dims=1,), check_inferred=false)

    # Case which attains max twice -- can't use FiniteDifferences for this.
    # The expected gradient is nonzero only at the first index attaining the extremum.
    if imum == maximum
        res = [0,1,0,0,0,0]
    else
        res = [1,0,0,0,0,0]
    end
    tied_dx = rrule(imum, [1,2,1,2,1,2])[2](1.0)[2]
    @test res == @inferred unthunk(tied_dx)

    # Structured matrix -- NB the minimum is a structural zero here
    diag_dx = rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]
    @test unthunk(diag_dx) isa Diagonal
    triu_dx = rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]
    @test unthunk(triu_dx) isa UpperTriangular{Float64}
    @test_skip test_rrule(imum, Diagonal(rand(3) .+ 1)) # MethodError: no method matching zero(::Type{Any}), from fill!(A::SparseArrays.SparseMatrixCSC{Any, Int64}, x::Bool)
end

@testset "extrema" begin
    test_rrule(extrema, rand(10), output_tangent = (rand(), rand()))
    per_column = collect(zip(rand(1,4), rand(1,4)))
    test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = per_column)

    # Case where both extrema are the same index, to check accumulation:
    test_rrule(extrema, rand(1), output_tangent = (rand(), rand()))
    single = hcat((rand(), rand()))
    test_rrule(extrema, rand(1,1), fkwargs=(dims=2,), output_tangent = single)
    per_row = collect(zip(rand(3,1), rand(3,1)))
    test_rrule(extrema, rand(3,1), fkwargs=(dims=2,), output_tangent = per_row)

    # Double-check the forward pass
    A = randn(3,4,5)
    A_primal = rrule(extrema, A, dims=(1,3))[1]
    @test extrema(A, dims=(1,3)) == A_primal
    B = hcat(A[:,:,1], A[:,:,1])
    B_primal = rrule(extrema, B, dims=2)[1]
    @test extrema(B, dims=2) == B_primal
end

4 comments on commit 605354c

@mzgubic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/49340

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.14.0 -m "<description of version>" 605354c19ae5805ff3a11b650c728915f03bff88
git push origin v1.14.0

Also, note the warning: Version 1.14.0 skips over 1.13.0
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@mzgubic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/49340

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.14.0 -m "<description of version>" 605354c19ae5805ff3a11b650c728915f03bff88
git push origin v1.14.0

Please sign in to comment.