-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Indexing #655
Indexing #655
Changes from all commits
09f98ca
e006af3
94e90fb
0e1c2c5
c1ebf3f
9e1aa8c
f49b118
79fcbd9
f11ce50
7d183f1
adab9b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -52,38 +52,111 @@ function rrule(::typeof(getindex), x::Tuple, ::Colon) | |||||
return x, getindex_back_4 | ||||||
end | ||||||
|
||||||
|
||||||
##### | ||||||
##### getindex | ||||||
##### getindex(::AbstractArray) | ||||||
##### | ||||||
|
||||||
function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...) | ||||||
return x[inds...], ẋ[inds...] | ||||||
end | ||||||
|
||||||
function rrule(::typeof(getindex), x::Array{<:Number}, inds...) | ||||||
# removes any logical indexing, CartesianIndex etc | ||||||
# leaving us just with a tuple of Int, Arrays of Int and Ranges of Int | ||||||
function rrule(::typeof(getindex), x::AbstractArray, inds...) | ||||||
function getindex_pullback(dy) | ||||||
nots = map(Returns(NoTangent()), inds) | ||||||
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) | ||||||
end | ||||||
return x[inds...], getindex_pullback | ||||||
end | ||||||
|
||||||
function thunked_∇getindex(x, dy, inds...) | ||||||
return InplaceableThunk( | ||||||
dx -> ∇getindex!(dx, unthunk(dy), Base.to_indices(x, inds)...), | ||||||
@thunk(∇getindex(x, unthunk(dy), inds...)), | ||||||
) | ||||||
end | ||||||
|
||||||
""" | ||||||
∇getindex(x, dy, inds...) | ||||||
|
||||||
For the `rrule` of `y = x[inds...]`, this function is roughly | ||||||
`setindex(zero(x), dy, inds...)`, returning the array `dx`. | ||||||
Differentiable. Includes `ProjectTo(x)(dx)`. | ||||||
""" | ||||||
function ∇getindex(x::AbstractArray, dy, inds...) | ||||||
# `to_indices` removes any logical indexing, colons, CartesianIndex etc, | ||||||
# leaving just Int / AbstractVector of Int | ||||||
plain_inds = Base.to_indices(x, inds) | ||||||
y = getindex(x, plain_inds...) | ||||||
function getindex_pullback(ȳ) | ||||||
function getindex_add!(Δ) | ||||||
# this a optimizes away for simple cases | ||||||
for (ȳ_ii, ii) in zip(ȳ, Iterators.product(plain_inds...)) | ||||||
Δ[ii...] += ȳ_ii | ||||||
end | ||||||
return Δ | ||||||
end | ||||||
dx = _setindex_zero(x, dy, plain_inds...) | ||||||
∇getindex!(dx, dy, plain_inds...) | ||||||
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules | ||||||
end | ||||||
|
||||||
""" | ||||||
_setindex_zero(x, dy, inds...) | ||||||
|
||||||
x̄ = InplaceableThunk( | ||||||
getindex_add!, | ||||||
@thunk(getindex_add!(zero(x))), | ||||||
) | ||||||
īnds = broadcast(Returns(NoTangent()), inds) | ||||||
return (NoTangent(), x̄, īnds...) | ||||||
This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`, | ||||||
and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what | ||||||
`∇getindex` does next. | ||||||
|
||||||
It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't | ||||||
allow `eltype(dy)`, nor does it work for many structured matrices. | ||||||
""" | ||||||
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), ZeroTangent()) | ||||||
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), ZeroTangent()) | ||||||
function _setindex_zero(x::AbstractArray, dy, inds::Integer...) | ||||||
# This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), | ||||||
# but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors | ||||||
T = Union{typeof(dy), ZeroTangent} | ||||||
return fill!(similar(x, T, axes(x)), ZeroTangent()) | ||||||
end | ||||||
function _setindex_zero(x::AbstractArray, dy, inds...) | ||||||
T = Union{eltype(dy), ZeroTangent} | ||||||
return fill!(similar(x, T, axes(x)), ZeroTangent()) | ||||||
end | ||||||
ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...) | ||||||
|
||||||
function ∇getindex!(dx::AbstractArray, dy, inds::Integer...) | ||||||
view(dx, inds...) .+= Ref(dy) | ||||||
return dx | ||||||
end | ||||||
function ∇getindex!(dx::AbstractArray, dy, inds...) | ||||||
view(dx, inds...) .+= dy | ||||||
return dx | ||||||
end | ||||||
|
||||||
# Allow for second derivatives, by writing rules for `∇getindex`: | ||||||
|
||||||
function frule((_, _, dẏ), ::typeof(∇getindex), x, dy, inds...) | ||||||
return ∇getindex(x, dy, inds...), ∇getindex(x, dẏ, inds...) | ||||||
end | ||||||
|
||||||
function rrule(::typeof(∇getindex), x, dy, inds...) | ||||||
z = ∇getindex(x, dy, inds...) | ||||||
function ∇getindex_pullback(dz) | ||||||
d2y = getindex(unthunk(dz), inds...) | ||||||
Comment on lines
+133
to
+136
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This second derivative function doesn't seem to infer well, can it be improved? |
||||||
nots = map(Returns(NoTangent()), inds) | ||||||
return (NoTangent(), NoTangent(), ProjectTo(dy)(d2y), nots...) | ||||||
end | ||||||
return z, ∇getindex_pullback | ||||||
end | ||||||
|
||||||
# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers. | ||||||
# To avoid this, copy everything back to the CPU. | ||||||
# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice: | ||||||
|
||||||
return y, getindex_pullback | ||||||
function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...) | ||||||
view(dx, inds...) .+= Ref(dy) | ||||||
return dx | ||||||
end | ||||||
function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...) | ||||||
view(dx, inds...) .+= dy | ||||||
return dx | ||||||
end | ||||||
function ∇getindex!(dx::AbstractGPUArray, dy, inds...) | ||||||
dx_cpu = adapt(Array, dx) | ||||||
view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy) | ||||||
copyto!(dx, dx_cpu) | ||||||
return dx | ||||||
end | ||||||
|
||||||
##### | ||||||
|
@@ -117,6 +190,23 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...) | |||||
return view(x, inds...), view(ẋ, inds...) | ||||||
end | ||||||
|
||||||
function rrule(::typeof(view), x::AbstractArray, inds...) | ||||||
function view_pullback(dy) | ||||||
nots = map(Returns(NoTangent()), inds) | ||||||
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...) | ||||||
end | ||||||
return view(x, inds...), view_pullback | ||||||
end | ||||||
|
||||||
function rrule(::typeof(view), x::AbstractArray, i::Integer, jkl::Integer...) | ||||||
# This case returns a zero-dim array, unlike getindex. So we fool ∇getindex: | ||||||
function view_pullback_0(dy) | ||||||
nots = map(Returns(NoTangent()), (i, jkl...)) | ||||||
return (NoTangent(), thunked_∇getindex(x, dy, i:i, jkl...), nots...) | ||||||
end | ||||||
return view(x, i, jkl...), view_pullback_0 | ||||||
end | ||||||
|
||||||
##### | ||||||
##### setindex! | ||||||
##### | ||||||
|
@@ -125,6 +215,21 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...) | |||||
return setindex!(x, v, inds...), setindex!(ẋ, v̇, inds...) | ||||||
end | ||||||
|
||||||
##### | ||||||
##### unsafe_getindex | ||||||
##### | ||||||
|
||||||
# This is called by e.g. `iterate(1:0.1:2)`, | ||||||
# and fixes https://github.com/FluxML/Zygote.jl/issues/1247 | ||||||
# Only needs to accept AbstractRange, but AbstractVector makes testing easier. | ||||||
|
||||||
function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) | ||||||
return Base.unsafe_getindex(x, i), getindex(ẋ, i) | ||||||
end | ||||||
|
||||||
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't need any mode does it?
Suggested change
Not sure why we aren't just calling the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess I thought the first Zygote use of this might leave its own rule intact for scalar indexing, to make OneElement for that, and use CR for all others. In which case More generally it may want to do other more efficient things for indexing ranges. Or we may want to do that here, and remove this entirely. |
||||||
return rrule_via_ad(cfg, getindex, x, i) | ||||||
end | ||||||
|
||||||
##### | ||||||
##### `eachslice` and friends | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we shouldn't call it
_setable_zero
or something?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe, I don't like any names. I guess it takes indices in a way that's like set/getindex functions.