From 09f98cad7b97a008f89caf391d4195aab9414416 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 25 May 2022 13:35:27 -0400 Subject: [PATCH 01/11] move and rename _zerolike_writeat, NFC --- src/rulesets/Base/array.jl | 41 +------------------------------- src/rulesets/Base/indexing.jl | 43 ++++++++++++++++++++++++++++++++++ src/rulesets/Base/sort.jl | 9 ++++--- test/rulesets/Base/array.jl | 8 ------- test/rulesets/Base/indexing.jl | 9 +++++++ 5 files changed, 57 insertions(+), 53 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 36ba52d55..5599ca552 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -519,7 +519,7 @@ for findm in (:findmin, :findmax) # This pullback is a lot like the one for getindex. Ideally they would probably be combined? function $findm_pullback((dy, _)) # this accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) dy isa AbstractZero && return (NoTangent(), NoTangent()) - x_thunk = @thunk project(_zerolike_writeat(x, unthunk(dy), dims, ind)) + x_thunk = @thunk project(∇getindex(x, unthunk(dy), dims, ind)) x_ithunk = InplaceableThunk(x_thunk) do dx if dims isa Colon view(dx, ind) .= view(dx, ind) .+ Ref(unthunk(dy)) @@ -534,45 +534,6 @@ for findm in (:findmin, :findmax) end end -# This function is roughly `setindex!(zero(x), dy, inds...)`: - -function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...) - _zero_fill = eltype(dy) == Any ? 0 : zero(eltype(dy)) - - # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't - # allow `eltype(dy)`, nor does it work for many structured matrices. - dx = fill!(similar(x, eltype(dy), axes(x)), _zero_fill) - view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray - dx -end -function _zerolike_writeat(x::AbstractArray, dy, dims, inds...) - # Since we have `x`, we can also handle arrays of arrays. - dx = map(zero, x) - if dims isa Colon - view(dx, inds...) .= Ref(dy) - else - view(dx, inds...) .= dy - end - dx -end - -# Allow for second derivatives, by writing rules for `_zerolike_writeat`; -# these rules are the reason it takes a `dims` argument. - -function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...) - return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...) -end - -function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...) - z = _zerolike_writeat(x, dy, dims, inds...) - function _zerolike_writeat_pullback(dz) - dx = sum(view(unthunk(dz), inds...); dims=dims) - nots = map(_ -> NoTangent(), inds) - return (NoTangent(), NoTangent(), dx, NoTangent(), nots...) - end - return z, _zerolike_writeat_pullback -end - # These rules for `maximum` pick the same subgradient as `findmax`: function frule((_, ẋ), ::typeof(maximum), x; dims=:) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index fbce009e0..9c58aa95e 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -86,6 +86,49 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...) return y, getindex_pullback end + +""" + ∇getindex(x, dy, dims, inds...) + +This function is roughly `setindex!(zero(x), dy, inds...)`. + +""" +function ∇getindex(x::AbstractArray{<:Number}, dy, dims, inds...) + # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't + # allow `eltype(dy)`, nor does it work for many structured matrices. + dx = fill!(similar(x, eltype(dy), axes(x)), 0) + view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray + dx +end +function ∇getindex(x::AbstractArray, dy, dims, inds...) + # Since we have `x`, we can also handle arrays of arrays. + dx = map(zero, x) + if dims isa Colon + view(dx, inds...) .= Ref(dy) + else + view(dx, inds...) .= dy + end + dx +end + +# Allow for second derivatives, by writing rules for `∇getindex`; +# these rules are the reason it takes a `dims` argument. + +function frule((_, _, dẏ), ::typeof(∇getindex), x, dy, dims, inds...) + return ∇getindex(x, dy, dims, inds...), ∇getindex(x, dẏ, dims, inds...) +end + +function rrule(::typeof(∇getindex), x, dy, dims, inds...) + z = ∇getindex(x, dy, dims, inds...) + function ∇getindex_pullback(dz) + dx = sum(view(unthunk(dz), inds...); dims=dims) + nots = map(_ -> NoTangent(), inds) + return (NoTangent(), NoTangent(), dx, NoTangent(), nots...) + end + return z, ∇getindex_pullback +end + + ##### ##### first, tail ##### diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 42f674a24..ba3f3ddaf 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -64,7 +64,7 @@ function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) function sortslices_pullback(dy) # No actual need to zero this, and if you didn't, then you could widen eltype # Also, you could use similar(dy) here not x, same size? - dx = _zerolike_writeat(x, unthunk(dy), (), inds...) + dx = ∇getindex(x, unthunk(dy), (), inds...) return (NoTangent(), ProjectTo(x)(dx)) end return x[inds...], sortslices_pullback @@ -94,12 +94,11 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1) keep = map(I -> I[1], findall(mask)) if dims isa Colon - # The function `_zerolike_writeat` allows second derivatives. - # Should perhaps eventually be shared with `getindex`. - dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x) + # The function `∇getindex` allows second derivatives. + dx = reshape(∇getindex(vec(x), vec(dy), (), keep), axes_x) else inds = ntuple(d -> d==dims ? keep : (:), length(axes_x)) - dx = _zerolike_writeat(x, dy, (), inds...) + dx = ∇getindex(x, dy, (), inds...) end return (NoTangent(), ProjectTo(x)(dx)) end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index dced0eb5e..42e3674e9 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -366,14 +366,6 @@ end @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent())) test_rrule(findmin, rand(3,4), fkwargs=(dims=2,)) - - # Second derivatives - test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) - test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) - @test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9) - y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)]) - @test y == [0 0; 5 5] - @test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent()) end @testset "$imum" for imum in [maximum, minimum] diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 29a7e2337..7f5efa682 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -77,6 +77,15 @@ test_rrule(getindex, x, [2,2], [3,3]) end end + + @testset "second derivatives" begin + test_frule(ChainRules.∇getindex, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) + test_rrule(ChainRules.∇getindex, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) + @test_skip test_rrule(ChainRules.∇getindex, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9) + y, bk = rrule(ChainRules.∇getindex, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)]) + @test y == [0 0; 5 5] + @test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent()) + end end @testset "first & tail" begin From e006af3a7309505bb4fbbc3011cd5daebc4c79dd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 1 Aug 2022 22:36:35 -0400 Subject: [PATCH 02/11] simplify, use it for getindex, tests --- src/rulesets/Base/array.jl | 17 +++--- src/rulesets/Base/indexing.jl | 94 +++++++++++++++++++--------------- src/rulesets/Base/sort.jl | 9 ++-- test/rulesets/Base/indexing.jl | 58 ++++++++++++++++++--- 4 files changed, 111 insertions(+), 67 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 5599ca552..174d4c180 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -515,20 +515,15 @@ for findm in (:findmin, :findmax) @eval function rrule(::typeof($findm), x::AbstractArray; dims=:) y, ind = $findm(x; dims=dims) + plain_inds = Base.to_indices(x, (ind,)) project = ProjectTo(x) - # This pullback is a lot like the one for getindex. Ideally they would probably be combined? function $findm_pullback((dy, _)) # this accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) dy isa AbstractZero && return (NoTangent(), NoTangent()) - x_thunk = @thunk project(∇getindex(x, unthunk(dy), dims, ind)) - x_ithunk = InplaceableThunk(x_thunk) do dx - if dims isa Colon - view(dx, ind) .= view(dx, ind) .+ Ref(unthunk(dy)) - else - view(dx, ind) .= view(dx, ind) .+ unthunk(dy) # this could be .+=, but not on Julia 1.0 - end - dx - end - return (NoTangent(), x_ithunk) + xthunk = InplaceableThunk( + dx -> ∇getindex!(dx, x, unthunk(dy), plain_inds...), + @thunk(∇getindex(x, unthunk(dy), plain_inds...)), + ) + return (NoTangent(), xthunk) end return (y, ind), $findm_pullback end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 9c58aa95e..4e431b345 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -54,81 +54,77 @@ end ##### -##### getindex +##### getindex(::AbstractArray) ##### function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...) return x[inds...], ẋ[inds...] end -function rrule(::typeof(getindex), x::Array{<:Number}, inds...) +function rrule(::typeof(getindex), x::AbstractArray, inds...) # removes any logical indexing, CartesianIndex etc # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int plain_inds = Base.to_indices(x, inds) y = getindex(x, plain_inds...) function getindex_pullback(ȳ) - function getindex_add!(Δ) - # this a optimizes away for simple cases - for (ȳ_ii, ii) in zip(ȳ, Iterators.product(plain_inds...)) - Δ[ii...] += ȳ_ii - end - return Δ - end - - x̄ = InplaceableThunk( - getindex_add!, - @thunk(getindex_add!(zero(x))), + xthunk = InplaceableThunk( + x̄ -> ∇getindex!(x̄, x, unthunk(ȳ), plain_inds...), + @thunk(∇getindex(x, unthunk(ȳ), plain_inds...)), ) - īnds = broadcast(Returns(NoTangent()), inds) - return (NoTangent(), x̄, īnds...) + nots = map(Returns(NoTangent()), inds) + return (NoTangent(), xthunk, nots...) end - return y, getindex_pullback end - """ - ∇getindex(x, dy, dims, inds...) - -This function is roughly `setindex!(zero(x), dy, inds...)`. + ∇getindex(x, dy, inds...) +For the `rrule` of `y = x[inds...]`, this function is roughly +`setindex(zero(x), dy, inds...)`, returning the array `dx`. +Differentiable. Includes `ProjectTo(x)(dx)`. """ -function ∇getindex(x::AbstractArray{<:Number}, dy, dims, inds...) +function ∇getindex(x::AbstractArray{<:Number}, dy, inds...) # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't # allow `eltype(dy)`, nor does it work for many structured matrices. dx = fill!(similar(x, eltype(dy), axes(x)), 0) - view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray - dx + ∇getindex!(dx, x, dy, inds...) + return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules end -function ∇getindex(x::AbstractArray, dy, dims, inds...) +function ∇getindex(x::AbstractArray, dy, inds...) # Since we have `x`, we can also handle arrays of arrays. - dx = map(zero, x) - if dims isa Colon - view(dx, inds...) .= Ref(dy) - else - view(dx, inds...) .= dy - end - dx + dx = map(zero, x) # this ignores type of dy, TODO? + ∇getindex!(dx, x, dy, inds...) + return ProjectTo(x)(dx) end -# Allow for second derivatives, by writing rules for `∇getindex`; -# these rules are the reason it takes a `dims` argument. +function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...) + view(dx, inds...) .+= Ref(dy) + return dx +end +function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds...) + view(dx, inds...) .+= dy + # For GPU arrays, `inds::Union{Integer, Base.Slice}...` is fine, but any other AbstractArray risks overwriting. + # Those should call `NNlib.scatter!`, alla https://github.com/FluxML/Zygote.jl/pull/1131 + return dx +end -function frule((_, _, dẏ), ::typeof(∇getindex), x, dy, dims, inds...) - return ∇getindex(x, dy, dims, inds...), ∇getindex(x, dẏ, dims, inds...) +# Allow for second derivatives, by writing rules for `∇getindex`: + +function frule((_, _, dẏ), ::typeof(∇getindex), x, dy, inds...) + return ∇getindex(x, dy, inds...), ∇getindex(x, dẏ, inds...) end -function rrule(::typeof(∇getindex), x, dy, dims, inds...) - z = ∇getindex(x, dy, dims, inds...) +function rrule(::typeof(∇getindex), x, dy, inds...) + z = ∇getindex(x, dy, inds...) function ∇getindex_pullback(dz) - dx = sum(view(unthunk(dz), inds...); dims=dims) - nots = map(_ -> NoTangent(), inds) - return (NoTangent(), NoTangent(), dx, NoTangent(), nots...) + d2y = getindex(unthunk(dz), inds...) + nots = map(Returns(NoTangent()), inds) + return (NoTangent(), NoTangent(), ProjectTo(dy)(d2y), nots...) end return z, ∇getindex_pullback end - ##### ##### first, tail ##### @@ -160,6 +156,21 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...) return view(x, inds...), view(ẋ, inds...) end +# Identical to `getindex` above: +function rrule(::typeof(view), x::AbstractArray, inds...) + plain_inds = Base.to_indices(x, inds) + y = view(x, plain_inds...) + function view_pullback(ȳ) + xthunk = InplaceableThunk( + x̄ -> ∇getindex!(x̄, x, unthunk(ȳ), plain_inds...), + @thunk(∇getindex(x, unthunk(ȳ), plain_inds...)), + ) + nots = map(Returns(NoTangent()), inds) + return (NoTangent(), xthunk, nots...) + end + return y, view_pullback +end + ##### ##### setindex! ##### @@ -168,7 +179,6 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...) return setindex!(x, v, inds...), setindex!(ẋ, v̇, inds...) end - ##### ##### `eachslice` and friends ##### diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index ba3f3ddaf..0805da91f 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -62,10 +62,7 @@ function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) inds = ntuple(d -> d == dims ? p : (:), ndims(x)) function sortslices_pullback(dy) - # No actual need to zero this, and if you didn't, then you could widen eltype - # Also, you could use similar(dy) here not x, same size? - dx = ∇getindex(x, unthunk(dy), (), inds...) - return (NoTangent(), ProjectTo(x)(dx)) + return (NoTangent(), ∇getindex(x, unthunk(dy), inds...)) end return x[inds...], sortslices_pullback end @@ -95,10 +92,10 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) keep = map(I -> I[1], findall(mask)) if dims isa Colon # The function `∇getindex` allows second derivatives. - dx = reshape(∇getindex(vec(x), vec(dy), (), keep), axes_x) + dx = reshape(∇getindex(vec(x), vec(dy), keep), axes_x) ## TODO understand again why vec! else inds = ntuple(d -> d==dims ? keep : (:), length(axes_x)) - dx = ∇getindex(x, dy, (), inds...) + dx = ∇getindex(x, dy, inds...) end return (NoTangent(), ProjectTo(x)(dx)) end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 7f5efa682..20c1d43ac 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -78,14 +78,56 @@ end end - @testset "second derivatives" begin - test_frule(ChainRules.∇getindex, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) - test_rrule(ChainRules.∇getindex, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) - @test_skip test_rrule(ChainRules.∇getindex, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9) - y, bk = rrule(ChainRules.∇getindex, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)]) - @test y == [0 0; 5 5] - @test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent()) - end + @testset "getindex for structured arrays" begin + test_frule(getindex, Diagonal(rand(3)), 1) + test_frule(getindex, Symmetric(rand(3, 3)), 2, 3) + + test_rrule(getindex, Diagonal(rand(3)), 1) + @test_skip test_rrule(getindex, Diagonal(rand(3)), 2, :) # in-place update of off-diagonal entries + dgrad = rrule(getindex, Diagonal(rand(3)), 2, :)[2]([1,2,3])[2] + @test unthunk(dgrad) ≈ Diagonal([0, 2, 0]) + + test_rrule(getindex, Symmetric(rand(3, 3)), 2, 2) + sgrad = rrule(getindex, Symmetric(rand(3, 3)), 2, 3)[2](1.0)[2] + @test unthunk(sgrad) ≈ [0 0 0; 0 0 1/2; 0 1/2 0] + end + + @testset "getindex(::Array{<:Array})" begin + test_frule(getindex, [rand(2) for _ in 1:3], 1) + test_frule(getindex, [rand(2) for _ in 1:3], 2:3) + test_frule(getindex, [rand(2) for _ in 1:3], [true, false, true]) + + test_rrule(getindex, [rand(2) for _ in 1:3], 1; check_inferred=false) + test_rrule(getindex, [rand(2) for _ in 1:3], 2:3; check_inferred=false) + test_frule(getindex, [rand(2) for _ in 1:3], [true, false, true]; check_inferred=false) + end + + @testset "second derivatives: ∇getindex" begin + @eval using ChainRules: ∇getindex + # Forward, scalar result + test_frule(∇getindex, rand(2, 3), rand(), 3) + test_frule(∇getindex, rand(2, 3), rand()+im, 2, 1) + # array result + test_frule(∇getindex, rand(2, 3), rand(2), 4:5) + test_frule(∇getindex, rand(2, 3), rand(3), 1, :) + test_frule(∇getindex, rand(2, 3), rand(1, 2), [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) + test_frule(∇getindex, rand(2, 3), rand(3), Bool[1 0 1; 0 1 0]) + # arrays of arrays + test_frule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3) + test_frule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2) + + # Reverse, scalar result + test_rrule(∇getindex, rand(2, 3), rand(), 3; check_inferred=false) + test_rrule(∇getindex, rand(2, 3), rand()+im, 2, 1; check_inferred=false) + # array result + test_rrule(∇getindex, rand(2, 3), rand(2), 4:5; check_inferred=false) + test_rrule(∇getindex, rand(2, 3), rand(3), 1, :; check_inferred=false) + test_rrule(∇getindex, rand(2, 3), rand(1, 2), [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent(); check_inferred=false) + test_rrule(∇getindex, rand(2, 3), rand(3), Bool[1 0 1; 0 1 0]; check_inferred=false) + # arrays of arrays + test_rrule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3; check_inferred=false) + test_rrule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2; check_inferred=false) + end end @testset "first & tail" begin From 94e90fb918be7f2974dacb99ad5d4daaa99865dd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 2 Aug 2022 09:31:35 -0400 Subject: [PATCH 03/11] add unsafe_getindex too --- src/rulesets/Base/indexing.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 4e431b345..408de313f 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -179,6 +179,17 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...) return setindex!(x, v, inds...), setindex!(ẋ, v̇, inds...) end +##### +##### unsafe_getindex +##### + +# This is called by e.g. `iterate(1:0.1:2)`, +# and fixes https://github.com/FluxML/Zygote.jl/issues/1247 + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractRange, i::Integer) + return rrule_via_ad(cfg, getindex, x, i) +end + ##### ##### `eachslice` and friends ##### From 0e1c2c5f6e131bf6776d0d5da26c93c3825baecc Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 2 Aug 2022 19:22:00 -0400 Subject: [PATCH 04/11] tidy, make weird types work via _setindex_zero --- src/rulesets/Base/array.jl | 8 +---- src/rulesets/Base/indexing.jl | 65 ++++++++++++++++++---------------- test/rulesets/Base/indexing.jl | 19 ++++++++++ 3 files changed, 54 insertions(+), 38 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 174d4c180..a14d0368e 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -515,15 +515,9 @@ for findm in (:findmin, :findmax) @eval function rrule(::typeof($findm), x::AbstractArray; dims=:) y, ind = $findm(x; dims=dims) - plain_inds = Base.to_indices(x, (ind,)) - project = ProjectTo(x) function $findm_pullback((dy, _)) # this accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) dy isa AbstractZero && return (NoTangent(), NoTangent()) - xthunk = InplaceableThunk( - dx -> ∇getindex!(dx, x, unthunk(dy), plain_inds...), - @thunk(∇getindex(x, unthunk(dy), plain_inds...)), - ) - return (NoTangent(), xthunk) + return (NoTangent(), thunked∇getindex(x, dy, ind),) end return (y, ind), $findm_pullback end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 408de313f..9b8a6f578 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -62,19 +62,18 @@ function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...) end function rrule(::typeof(getindex), x::AbstractArray, inds...) - # removes any logical indexing, CartesianIndex etc - # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int - plain_inds = Base.to_indices(x, inds) - y = getindex(x, plain_inds...) - function getindex_pullback(ȳ) - xthunk = InplaceableThunk( - x̄ -> ∇getindex!(x̄, x, unthunk(ȳ), plain_inds...), - @thunk(∇getindex(x, unthunk(ȳ), plain_inds...)), - ) + function getindex_pullback(dy) nots = map(Returns(NoTangent()), inds) - return (NoTangent(), xthunk, nots...) + return (NoTangent(), thunked∇getindex(x, dy, inds...), nots...) end - return y, getindex_pullback + return x[inds...], getindex_pullback +end + +function thunked∇getindex(x, dy, inds...) + return InplaceableThunk( + dx -> ∇getindex!(dx, x, unthunk(dy), Base.to_indices(x, inds)...), + @thunk(∇getindex(x, unthunk(dy), inds...)), + ) end """ @@ -84,19 +83,30 @@ For the `rrule` of `y = x[inds...]`, this function is roughly `setindex(zero(x), dy, inds...)`, returning the array `dx`. Differentiable. Includes `ProjectTo(x)(dx)`. """ -function ∇getindex(x::AbstractArray{<:Number}, dy, inds...) - # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't - # allow `eltype(dy)`, nor does it work for many structured matrices. - dx = fill!(similar(x, eltype(dy), axes(x)), 0) - ∇getindex!(dx, x, dy, inds...) +function ∇getindex(x::AbstractArray, dy, inds...) + # `to_indices` removes any logical indexing, colons, CartesianIndex etc, + # leaving just Int / AbstractVector of Int + plain_inds = Base.to_indices(x, inds) + dx = _setindex_zero(x, dy, plain_inds...) + ∇getindex!(dx, x, dy, plain_inds...) return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules end -function ∇getindex(x::AbstractArray, dy, inds...) - # Since we have `x`, we can also handle arrays of arrays. - dx = map(zero, x) # this ignores type of dy, TODO? - ∇getindex!(dx, x, dy, inds...) - return ProjectTo(x)(dx) + +# It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't +# allow `eltype(dy)`, nor does it work for many structured matrices. +_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), 0) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), 0) +function _setindex_zero(x::AbstractArray, dy, inds::Integer...) + # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), + # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors + T = Union{typeof(dy), ZeroTangent} + return fill!(similar(x, T, axes(x)), ZeroTangent()) end +function _setindex_zero(x::AbstractArray, dy, inds...) + T = Union{eltype(dy), ZeroTangent} + return fill!(similar(x, T, axes(x)), ZeroTangent()) +end +ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...) function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...) view(dx, inds...) .+= Ref(dy) @@ -156,19 +166,12 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...) return view(x, inds...), view(ẋ, inds...) end -# Identical to `getindex` above: function rrule(::typeof(view), x::AbstractArray, inds...) - plain_inds = Base.to_indices(x, inds) - y = view(x, plain_inds...) - function view_pullback(ȳ) - xthunk = InplaceableThunk( - x̄ -> ∇getindex!(x̄, x, unthunk(ȳ), plain_inds...), - @thunk(∇getindex(x, unthunk(ȳ), plain_inds...)), - ) + function view_pullback(dy) nots = map(Returns(NoTangent()), inds) - return (NoTangent(), xthunk, nots...) + return (NoTangent(), thunked∇getindex(x, dy, inds...), nots...) end - return y, view_pullback + return view(x, inds...), view_pullback end ##### diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 20c1d43ac..a02026fe7 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -101,6 +101,21 @@ test_rrule(getindex, [rand(2) for _ in 1:3], 2:3; check_inferred=false) test_frule(getindex, [rand(2) for _ in 1:3], [true, false, true]; check_inferred=false) end + + @testset "getindex(::Array{<:Weird})" begin + xfix = [Base.Fix1(*, pi), Base.Fix1(^, ℯ), Base.Fix1(/, -1)] + dxfix = [Tangent{Base.Fix1}(; x = i/10) for i in 1:3] + # test_frule(getindex, xfix ⊢ dxfix, 1) + # test_rrule(getindex, xfix ⊢ dxfix, 1) + + dx1 = unthunk(rrule(getindex, xfix, 1)[2](dxfix[1])[2]) + @test dx1[1] == dxfix[1] + @test iszero(dx1[2]) + + dx23 = unthunk(rrule(getindex, xfix, 2:3)[2](dxfix[2:3])[2]) + @test iszero(dx23[1]) + @test dx23[3] == dxfix[3] + end @testset "second derivatives: ∇getindex" begin @eval using ChainRules: ∇getindex @@ -151,6 +166,10 @@ end test_frule(view, rand(3, 4), :, 1) test_frule(view, rand(3, 4), 2, [1, 1, 2]) test_frule(view, rand(3, 4), 3, 4) + + test_rrule(view, rand(3, 4), :, 1) + test_rrule(view, rand(3, 4), 2, [1, 1, 2]) + @test_broken test_rrule(view, rand(3, 4), 3, 4) # This is why ∇getindex needs one more argument, dammit end @testset "setindex!" begin From c1ebf3fbb1a65df40d38ba0c49e0acb55f2202d4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 2 Aug 2022 22:09:42 -0400 Subject: [PATCH 05/11] fix view & its zero-arrays --- src/rulesets/Base/indexing.jl | 10 +++++++++- test/rulesets/Base/array.jl | 1 + test/rulesets/Base/indexing.jl | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 9b8a6f578..7b6568e09 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -52,7 +52,6 @@ function rrule(::typeof(getindex), x::Tuple, ::Colon) return x, getindex_back_4 end - ##### ##### getindex(::AbstractArray) ##### @@ -174,6 +173,15 @@ function rrule(::typeof(view), x::AbstractArray, inds...) return view(x, inds...), view_pullback end +function rrule(::typeof(view), x::AbstractArray, i::Integer, jkl::Integer...) + # This case returns a zero-dim array, unlike getindex. So we fool ∇getindex: + function view_pullback_0(dy) + nots = map(Returns(NoTangent()), (i, jkl...)) + return (NoTangent(), thunked∇getindex(x, dy, i:i, jkl...), nots...) + end + return view(x, i, jkl...), view_pullback_0 +end + ##### ##### setindex! ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 42e3674e9..075c5e050 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -366,6 +366,7 @@ end @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent())) test_rrule(findmin, rand(3,4), fkwargs=(dims=2,)) + test_rrule(findmin, rand(3,4), fkwargs=(dims=(1,2),)) end @testset "$imum" for imum in [maximum, minimum] diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index a02026fe7..695f06010 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -169,7 +169,7 @@ end test_rrule(view, rand(3, 4), :, 1) test_rrule(view, rand(3, 4), 2, [1, 1, 2]) - @test_broken test_rrule(view, rand(3, 4), 3, 4) # This is why ∇getindex needs one more argument, dammit + test_rrule(view, rand(3, 4), 3, 4) end @testset "setindex!" begin From 9e1aa8c9a0df72278987dd96b3dade9599ba5581 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 2 Aug 2022 23:09:46 -0400 Subject: [PATCH 06/11] test unsafe_getindex --- Project.toml | 2 +- src/rulesets/Base/indexing.jl | 9 +++++++-- test/rulesets/Base/indexing.jl | 5 +++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index e44739e88..64b86cc45 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.43.2" +version = "1.44.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 7b6568e09..db25537b2 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -196,9 +196,14 @@ end # This is called by e.g. `iterate(1:0.1:2)`, # and fixes https://github.com/FluxML/Zygote.jl/issues/1247 +# Only needs to accept AbstractRange, but AbstractVector makes testing easier. -function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractRange, i::Integer) - return rrule_via_ad(cfg, getindex, x, i) +function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) + return Base.unsafe_getindex(x, i), getindex(ẋ, i) +end + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) + return rrule_via_ad(cfg, getindex, x, i) end ##### diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 695f06010..0d5b09398 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -177,6 +177,11 @@ end test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3) end +@testset "unsafe_getindex" begin + test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3) + test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3) +end + @testset "eachslice" begin # Testing eachrow not collect∘eachrow leads to errors, e.g. # test_rrule: eachrow on Vector{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/8dFTY/src/testers.jl:195 From f49b118c00cbf8653616d5f2d4803a0a851a9377 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 3 Aug 2022 00:59:53 -0400 Subject: [PATCH 07/11] handle indexing of GPU arrays --- Project.toml | 5 +++-- src/ChainRules.jl | 3 ++- src/rulesets/Base/base.jl | 20 ++++++++++++++++++++ src/rulesets/Base/indexing.jl | 21 +++++++++++++++++++-- test/rulesets/Base/indexing.jl | 20 ++++++++++++++++++++ test/runtests.jl | 3 +++ 6 files changed, 67 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 64b86cc45..b66176754 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" version = "1.44.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -16,6 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] +Adapt = "3.4.0" ChainRulesCore = "1.15.3" ChainRulesTestUtils = "1.5" Compat = "3.42.0, 4" @@ -30,7 +32,6 @@ StructArrays = "0.6.11" julia = "1.6" [extras] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -40,4 +41,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index b314d7be7..30e492d2e 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -1,10 +1,11 @@ module ChainRules +using Adapt: adapt using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using ChainRulesCore using Compat using Distributed -using GPUArraysCore: AbstractGPUArrayStyle +using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle using IrrationalConstants: logtwo, logten using LinearAlgebra using LinearAlgebra.BLAS diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index c10ba6e71..d486fd67e 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -243,3 +243,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...) return y, map_pullback end + +##### +##### `task_local_storage` +##### + +# Called by `@allowscalar` from GPUArrays + +ChainRules.@non_differentiable task_local_storage(key::Any) +ChainRules.@non_differentiable task_local_storage(key::Any, value::Any) + +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value) + y, back = task_local_storage(key, value) do + rrule_via_ad(config, body) + end + function task_local_storage_pullback(dy) + dbody = only(back(dy)) + return (NoTangent(), dbody, NoTangent(), NoTangent()) + end + return y, task_local_storage_pullback +end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index db25537b2..14bdf747d 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -113,8 +113,6 @@ function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...) end function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds...) view(dx, inds...) .+= dy - # For GPU arrays, `inds::Union{Integer, Base.Slice}...` is fine, but any other AbstractArray risks overwriting. - # Those should call `NNlib.scatter!`, alla https://github.com/FluxML/Zygote.jl/pull/1131 return dx end @@ -134,6 +132,25 @@ function rrule(::typeof(∇getindex), x, dy, inds...) return z, ∇getindex_pullback end +# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers. +# To avoid this, copy everything back to the CPU. +# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice: + +function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Integer...) + view(dx, inds...) .+= Ref(dy) + return dx +end +function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...) + view(dx, inds...) .+= dy + return dx +end +function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds...) + dx_cpu = adapt(Array, dx) + view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy) + copyto!(dx, dx_cpu) + return dx +end + ##### ##### first, tail ##### diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 0d5b09398..c2b9203a4 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -143,6 +143,25 @@ test_rrule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3; check_inferred=false) test_rrule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2; check_inferred=false) end + + @testset "GPU" begin + x_23_gpu = jl(rand(2, 3)) + + # Scalar indexing, copied from: @macroexpand @allowscalar A[i] + # Gives an error in Pkg.test, no idea why + # y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed) + # @test y1 == @allowscalar x_gpu[1] + # bk1(1.0) # This is zero, because finite-differencing ignores the function + # ... but this works, and calls the rule: + # Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1] + + y2, bk2 = rrule(getindex, x_23_gpu, :, 2:3) # fast path, just broadcast .+= + @test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1]) + + y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU + @test_skip Array(y3) == Array(x_gpu)[1, [1,1,2]] # error in Pkg.test, no idea why + @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) + end end @testset "first & tail" begin @@ -178,6 +197,7 @@ end end @testset "unsafe_getindex" begin + # In real life this is called only on some AbstractRanges, but easier to test on Array: test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3) test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3) end diff --git a/test/runtests.jl b/test/runtests.jl index 9ac5c5981..71444f388 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,12 +2,15 @@ using Test, ChainRulesCore, ChainRulesTestUtils @nospecialize +using Adapt using Base.Broadcast: broadcastable using ChainRules using ChainRulesCore using ChainRulesTestUtils using ChainRulesTestUtils: rand_tangent, _fdm using FiniteDifferences +using GPUArraysCore +using JLArrays using LinearAlgebra using LinearAlgebra.BLAS using LinearAlgebra: dot From 79fcbd9e112ba1fe40fcbcbba17b4a4ee5b18ea0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 11 Aug 2022 15:26:18 -0700 Subject: [PATCH 08/11] suggested changes --- src/rulesets/Base/array.jl | 2 +- src/rulesets/Base/indexing.jl | 38 ++++++++++++++++++++-------------- test/rulesets/Base/indexing.jl | 2 +- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index a14d0368e..49a4e1ac2 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -517,7 +517,7 @@ for findm in (:findmin, :findmax) y, ind = $findm(x; dims=dims) function $findm_pullback((dy, _)) # this accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) dy isa AbstractZero && return (NoTangent(), NoTangent()) - return (NoTangent(), thunked∇getindex(x, dy, ind),) + return (NoTangent(), thunked_∇getindex(x, dy, ind),) end return (y, ind), $findm_pullback end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 14bdf747d..76ab59232 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -63,14 +63,14 @@ end function rrule(::typeof(getindex), x::AbstractArray, inds...) function getindex_pullback(dy) nots = map(Returns(NoTangent()), inds) - return (NoTangent(), thunked∇getindex(x, dy, inds...), nots...) + return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) end return x[inds...], getindex_pullback end -function thunked∇getindex(x, dy, inds...) +function thunked_∇getindex(x, dy, inds...) return InplaceableThunk( - dx -> ∇getindex!(dx, x, unthunk(dy), Base.to_indices(x, inds)...), + dx -> ∇getindex!(dx, unthunk(dy), Base.to_indices(x, inds)...), @thunk(∇getindex(x, unthunk(dy), inds...)), ) end @@ -87,14 +87,22 @@ function ∇getindex(x::AbstractArray, dy, inds...) # leaving just Int / AbstractVector of Int plain_inds = Base.to_indices(x, inds) dx = _setindex_zero(x, dy, plain_inds...) - ∇getindex!(dx, x, dy, plain_inds...) + ∇getindex!(dx, dy, plain_inds...) return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules end -# It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't -# allow `eltype(dy)`, nor does it work for many structured matrices. -_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), 0) -_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), 0) +""" + _setindex_zero(x, dy, inds...) + +This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`, +and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what +`∇getindex` does next. + +It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't +allow `eltype(dy)`, nor does it work for many structured matrices. +""" +_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), ZeroTangent()) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), ZeroTangent()) function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors @@ -107,11 +115,11 @@ function _setindex_zero(x::AbstractArray, dy, inds...) end ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...) -function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...) +function ∇getindex!(dx::AbstractArray, dy, inds::Integer...) view(dx, inds...) .+= Ref(dy) return dx end -function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds...) +function ∇getindex!(dx::AbstractArray, dy, inds...) view(dx, inds...) .+= dy return dx end @@ -136,15 +144,15 @@ end # To avoid this, copy everything back to the CPU. # But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice: -function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Integer...) +function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...) view(dx, inds...) .+= Ref(dy) return dx end -function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...) +function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...) view(dx, inds...) .+= dy return dx end -function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds...) +function ∇getindex!(dx::AbstractGPUArray, dy, inds...) dx_cpu = adapt(Array, dx) view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy) copyto!(dx, dx_cpu) @@ -185,7 +193,7 @@ end function rrule(::typeof(view), x::AbstractArray, inds...) function view_pullback(dy) nots = map(Returns(NoTangent()), inds) - return (NoTangent(), thunked∇getindex(x, dy, inds...), nots...) + return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) end return view(x, inds...), view_pullback end @@ -194,7 +202,7 @@ function rrule(::typeof(view), x::AbstractArray, i::Integer, jkl::Integer...) # This case returns a zero-dim array, unlike getindex. So we fool ∇getindex: function view_pullback_0(dy) nots = map(Returns(NoTangent()), (i, jkl...)) - return (NoTangent(), thunked∇getindex(x, dy, i:i, jkl...), nots...) + return (NoTangent(), thunked_∇getindex(x, dy, i:i, jkl...), nots...) end return view(x, i, jkl...), view_pullback_0 end diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index c2b9203a4..7e37a86b4 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -83,7 +83,7 @@ test_frule(getindex, Symmetric(rand(3, 3)), 2, 3) test_rrule(getindex, Diagonal(rand(3)), 1) - @test_skip test_rrule(getindex, Diagonal(rand(3)), 2, :) # in-place update of off-diagonal entries + @test_skip test_rrule(getindex, Diagonal(rand(3)), 2, :) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/260 dgrad = rrule(getindex, Diagonal(rand(3)), 2, :)[2]([1,2,3])[2] @test unthunk(dgrad) ≈ Diagonal([0, 2, 0]) From f11ce50ec82d118edcb9eb30f19bfa076dac57a3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 11 Aug 2022 20:46:50 -0700 Subject: [PATCH 09/11] restore some gpu tests --- test/rulesets/Base/indexing.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 7e37a86b4..f5d9e0ed0 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -144,14 +144,13 @@ test_rrule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2; check_inferred=false) end - @testset "GPU" begin - x_23_gpu = jl(rand(2, 3)) + @testset "getindex(::AbstractGPUArray)" begin + x_23_gpu = jl(rand(2, 3)) # using JLArrays, loaded for @gpu in test_helpers.jl # Scalar indexing, copied from: @macroexpand @allowscalar A[i] - # Gives an error in Pkg.test, no idea why - # y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed) - # @test y1 == @allowscalar x_gpu[1] - # bk1(1.0) # This is zero, because finite-differencing ignores the function + y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed) + @test y1 == @allowscalar x_23_gpu[1] + bk1(1.0) # This is zero, because finite-differencing ignores the function # ... but this works, and calls the rule: # Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1] @@ -159,7 +158,7 @@ @test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1]) y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU - @test_skip Array(y3) == Array(x_gpu)[1, [1,1,2]] # error in Pkg.test, no idea why + @test Array(y3) == Array(x_gpu)[1, [1,1,2]] @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) end end From 7d183f1750f4d689e793fafe5409c878c98636ee Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 11 Aug 2022 21:01:37 -0700 Subject: [PATCH 10/11] avoid the error --- test/rulesets/Base/indexing.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index f5d9e0ed0..52731e26d 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -150,15 +150,15 @@ # Scalar indexing, copied from: @macroexpand @allowscalar A[i] y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed) @test y1 == @allowscalar x_23_gpu[1] - bk1(1.0) # This is zero, because finite-differencing ignores the function - # ... but this works, and calls the rule: + # bk1(1.0) # This gives a StackOverflowError! + # Also gives zero in global scope, error when in a let block? But this works, and calls the rule: # Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1] - + y2, bk2 = rrule(getindex, x_23_gpu, :, 2:3) # fast path, just broadcast .+= @test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1]) y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU - @test Array(y3) == Array(x_gpu)[1, [1,1,2]] + @test Array(y3) == Array(x_23_gpu)[1, [1,1,2]] @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) end end From adab9b0aa00cc1a9ba77b86c916b73f42c2fd96b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 11 Aug 2022 22:52:37 -0700 Subject: [PATCH 11/11] in fact, mystery errors persist --- test/rulesets/Base/indexing.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 52731e26d..a0748276c 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -148,10 +148,15 @@ x_23_gpu = jl(rand(2, 3)) # using JLArrays, loaded for @gpu in test_helpers.jl # Scalar indexing, copied from: @macroexpand @allowscalar A[i] - y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed) - @test y1 == @allowscalar x_23_gpu[1] - # bk1(1.0) # This gives a StackOverflowError! - # Also gives zero in global scope, error when in a let block? But this works, and calls the rule: + @test_skip begin # This gives + y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed) + @test y1 == @allowscalar x_23_gpu[1] + end + @test_skip begin + bk1(1.0) # This gives a StackOverflowError! Or gives zero in global scope. + true + end + # But this works, and calls the rule: # Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1] y2, bk2 = rrule(getindex, x_23_gpu, :, 2:3) # fast path, just broadcast .+=