Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Indexing #655

Merged
merged 11 commits into from
Aug 12, 2022
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.43.2"
version = "1.44.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -16,6 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
Adapt = "3.4.0"
ChainRulesCore = "1.15.3"
ChainRulesTestUtils = "1.5"
Compat = "3.42.0, 4"
Expand All @@ -30,7 +32,6 @@ StructArrays = "0.6.11"
julia = "1.6"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand All @@ -40,4 +41,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
3 changes: 2 additions & 1 deletion src/ChainRules.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module ChainRules

using Adapt: adapt
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using ChainRulesCore
using Compat
using Distributed
using GPUArraysCore: AbstractGPUArrayStyle
using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle
using IrrationalConstants: logtwo, logten
using LinearAlgebra
using LinearAlgebra.BLAS
Expand Down
52 changes: 1 addition & 51 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,64 +515,14 @@ for findm in (:findmin, :findmax)

@eval function rrule(::typeof($findm), x::AbstractArray; dims=:)
y, ind = $findm(x; dims=dims)
project = ProjectTo(x)
# This pullback is a lot like the one for getindex. Ideally they would probably be combined?
function $findm_pullback((dy, _)) # this accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing)
dy isa AbstractZero && return (NoTangent(), NoTangent())
x_thunk = @thunk project(_zerolike_writeat(x, unthunk(dy), dims, ind))
x_ithunk = InplaceableThunk(x_thunk) do dx
if dims isa Colon
view(dx, ind) .= view(dx, ind) .+ Ref(unthunk(dy))
else
view(dx, ind) .= view(dx, ind) .+ unthunk(dy) # this could be .+=, but not on Julia 1.0
end
dx
end
return (NoTangent(), x_ithunk)
return (NoTangent(), thunked∇getindex(x, dy, ind),)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
end
return (y, ind), $findm_pullback
end
end

# This function is roughly `setindex!(zero(x), dy, inds...)`:

function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...)
_zero_fill = eltype(dy) == Any ? 0 : zero(eltype(dy))

# It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
# allow `eltype(dy)`, nor does it work for many structured matrices.
dx = fill!(similar(x, eltype(dy), axes(x)), _zero_fill)
view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray
dx
end
function _zerolike_writeat(x::AbstractArray, dy, dims, inds...)
# Since we have `x`, we can also handle arrays of arrays.
dx = map(zero, x)
if dims isa Colon
view(dx, inds...) .= Ref(dy)
else
view(dx, inds...) .= dy
end
dx
end

# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
# these rules are the reason it takes a `dims` argument.

function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...)
end

function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
z = _zerolike_writeat(x, dy, dims, inds...)
function _zerolike_writeat_pullback(dz)
dx = sum(view(unthunk(dz), inds...); dims=dims)
nots = map(_ -> NoTangent(), inds)
return (NoTangent(), NoTangent(), dx, NoTangent(), nots...)
end
return z, _zerolike_writeat_pullback
end

# These rules for `maximum` pick the same subgradient as `findmax`:

function frule((_, ẋ), ::typeof(maximum), x; dims=:)
Expand Down
20 changes: 20 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu
map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
return y, map_pullback
end

#####
##### `task_local_storage`
#####

# Called by `@allowscalar` from GPUArrays

ChainRules.@non_differentiable task_local_storage(key::Any)
ChainRules.@non_differentiable task_local_storage(key::Any, value::Any)

function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value)
y, back = task_local_storage(key, value) do
rrule_via_ad(config, body)
end
function task_local_storage_pullback(dy)
dbody = only(back(dy))
return (NoTangent(), dbody, NoTangent(), NoTangent())
end
return y, task_local_storage_pullback
end
139 changes: 118 additions & 21 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,38 +52,103 @@ function rrule(::typeof(getindex), x::Tuple, ::Colon)
return x, getindex_back_4
end


#####
##### getindex
##### getindex(::AbstractArray)
#####

function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...)
return x[inds...], ẋ[inds...]
end

function rrule(::typeof(getindex), x::Array{<:Number}, inds...)
# removes any logical indexing, CartesianIndex etc
# leaving us just with a tuple of Int, Arrays of Int and Ranges of Int
function rrule(::typeof(getindex), x::AbstractArray, inds...)
function getindex_pullback(dy)
nots = map(Returns(NoTangent()), inds)
return (NoTangent(), thunked∇getindex(x, dy, inds...), nots...)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
end
return x[inds...], getindex_pullback
end

function thunked∇getindex(x, dy, inds...)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
return InplaceableThunk(
dx -> ∇getindex!(dx, x, unthunk(dy), Base.to_indices(x, inds)...),
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
@thunk(∇getindex(x, unthunk(dy), inds...)),
)
end

"""
∇getindex(x, dy, inds...)

For the `rrule` of `y = x[inds...]`, this function is roughly
`setindex(zero(x), dy, inds...)`, returning the array `dx`.
Differentiable. Includes `ProjectTo(x)(dx)`.
"""
function ∇getindex(x::AbstractArray, dy, inds...)
# `to_indices` removes any logical indexing, colons, CartesianIndex etc,
# leaving just Int / AbstractVector of Int
plain_inds = Base.to_indices(x, inds)
y = getindex(x, plain_inds...)
function getindex_pullback(ȳ)
function getindex_add!(Δ)
# this a optimizes away for simple cases
for (ȳ_ii, ii) in zip(ȳ, Iterators.product(plain_inds...))
Δ[ii...] += ȳ_ii
end
return Δ
end
dx = _setindex_zero(x, dy, plain_inds...)
∇getindex!(dx, x, dy, plain_inds...)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
_setindex_zero(x::AbstractArray, dy, inds...)
Basically this function is like `zero(x)` except that it ensure that
it is possible to set the value at index `inds` to `dy`.
It does this while preserving at least the outermost the structure of `x`.
Like `zero(x)`, it promises that `x == x + _setindex_zero(x, dy, inds...)` for all inputs; i.e. it always returns an additive identity.
"""
function _setindex_zero(x::AbstractArray, dy, inds...) end

I wonder if we shouldn't call it _setable_zero or something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, I don't like any names. I guess it takes indices in a way that's like set/getindex functions.

# It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
# allow `eltype(dy)`, nor does it work for many structured matrices.
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), 0)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), 0)
function _setindex_zero(x::AbstractArray, dy, inds::Integer...)
# This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent),
# but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors
T = Union{typeof(dy), ZeroTangent}
return fill!(similar(x, T, axes(x)), ZeroTangent())
end
function _setindex_zero(x::AbstractArray, dy, inds...)
T = Union{eltype(dy), ZeroTangent}
return fill!(similar(x, T, axes(x)), ZeroTangent())
end
ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...)

function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...)
view(dx, inds...) .+= Ref(dy)
return dx
end
function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds...)
view(dx, inds...) .+= dy
return dx
end
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

# Allow for second derivatives, by writing rules for `∇getindex`:

function frule((_, _, dẏ), ::typeof(∇getindex), x, dy, inds...)
return ∇getindex(x, dy, inds...), ∇getindex(x, dẏ, inds...)
end

x̄ = InplaceableThunk(
getindex_add!,
@thunk(getindex_add!(zero(x))),
)
īnds = broadcast(Returns(NoTangent()), inds)
return (NoTangent(), x̄, īnds...)
function rrule(::typeof(∇getindex), x, dy, inds...)
z = ∇getindex(x, dy, inds...)
function ∇getindex_pullback(dz)
d2y = getindex(unthunk(dz), inds...)
Comment on lines +133 to +136
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This second derivative function doesn't seem to infer well, can it be improved?

nots = map(Returns(NoTangent()), inds)
return (NoTangent(), NoTangent(), ProjectTo(dy)(d2y), nots...)
end
return z, ∇getindex_pullback
end

# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers.
# To avoid this, copy everything back to the CPU.
# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice:

return y, getindex_pullback
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Integer...)
view(dx, inds...) .+= Ref(dy)
return dx
end
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...)
view(dx, inds...) .+= dy
return dx
end
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds...)
dx_cpu = adapt(Array, dx)
view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy)
copyto!(dx, dx_cpu)
return dx
end
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

#####
Expand Down Expand Up @@ -117,6 +182,23 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...)
return view(x, inds...), view(ẋ, inds...)
end

function rrule(::typeof(view), x::AbstractArray, inds...)
function view_pullback(dy)
nots = map(Returns(NoTangent()), inds)
return (NoTangent(), thunked∇getindex(x, dy, inds...), nots...)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
end
return view(x, inds...), view_pullback
end

function rrule(::typeof(view), x::AbstractArray, i::Integer, jkl::Integer...)
# This case returns a zero-dim array, unlike getindex. So we fool ∇getindex:
function view_pullback_0(dy)
nots = map(Returns(NoTangent()), (i, jkl...))
return (NoTangent(), thunked∇getindex(x, dy, i:i, jkl...), nots...)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
end
return view(x, i, jkl...), view_pullback_0
end

#####
##### setindex!
#####
Expand All @@ -125,6 +207,21 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...)
return setindex!(x, v, inds...), setindex!(ẋ, v̇, inds...)
end

#####
##### unsafe_getindex
#####

# This is called by e.g. `iterate(1:0.1:2)`,
# and fixes https://github.com/FluxML/Zygote.jl/issues/1247
# Only needs to accept AbstractRange, but AbstractVector makes testing easier.

function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
return Base.unsafe_getindex(x, i), getindex(ẋ, i)
end

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need any mode does it?

Suggested change
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
function rrule(cfg::RuleConfig, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)

Not sure why we aren't just calling the rrule for get infact?
Or putting a Union{typeof(getindex), typeof(Base.getindex)) in the function arg on them.
(might even be able to stick view in that union too?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I thought the first Zygote use of this might leave its own rule intact for scalar indexing, to make OneElement for that, and use CR for all others. In which case rrule_via_ad here will call that.

More generally it may want to do other more efficient things for indexing ranges. Or we may want to do that here, and remove this entirely.

return rrule_via_ad(cfg, getindex, x, i)
end

#####
##### `eachslice` and friends
Expand Down
12 changes: 4 additions & 8 deletions src/rulesets/Base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
p = sortperm(collect(eachslice(x; dims=dims)); kw...)
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
function sortslices_pullback(dy)
# No actual need to zero this, and if you didn't, then you could widen eltype
# Also, you could use similar(dy) here not x, same size?
dx = _zerolike_writeat(x, unthunk(dy), (), inds...)
return (NoTangent(), ProjectTo(x)(dx))
return (NoTangent(), ∇getindex(x, unthunk(dy), inds...))
end
return x[inds...], sortslices_pullback
end
Expand Down Expand Up @@ -94,12 +91,11 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:)
mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1)
keep = map(I -> I[1], findall(mask))
if dims isa Colon
# The function `_zerolike_writeat` allows second derivatives.
# Should perhaps eventually be shared with `getindex`.
dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x)
# The function `∇getindex` allows second derivatives.
dx = reshape(∇getindex(vec(x), vec(dy), keep), axes_x) ## TODO understand again why vec!
else
inds = ntuple(d -> d==dims ? keep : (:), length(axes_x))
dx = _zerolike_writeat(x, dy, (), inds...)
dx = ∇getindex(x, dy, inds...)
end
return (NoTangent(), ProjectTo(x)(dx))
end
Expand Down
9 changes: 1 addition & 8 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,7 @@ end
@test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2])
test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()))
test_rrule(findmin, rand(3,4), fkwargs=(dims=2,))

# Second derivatives
test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
@test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9)
y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)])
@test y == [0 0; 5 5]
@test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent())
test_rrule(findmin, rand(3,4), fkwargs=(dims=(1,2),))
end

@testset "$imum" for imum in [maximum, minimum]
Expand Down
Loading