-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Indexing #655
Indexing #655
Conversation
function rrule(::typeof(∇getindex), x, dy, inds...) | ||
z = ∇getindex(x, dy, inds...) | ||
function ∇getindex_pullback(dz) | ||
d2y = getindex(unthunk(dz), inds...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This second derivative function doesn't seem to infer well, can it be improved?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool cool once these comments are addressed as you feel best
merge when happy
∇getindex!(dx, x, dy, plain_inds...) | ||
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
""" | |
_setindex_zero(x::AbstractArray, dy, inds...) | |
Basically this function is like `zero(x)` except that it ensure that | |
it is possible to set the value at index `inds` to `dy`. | |
It does this while preserving at least the outermost the structure of `x`. | |
Like `zero(x)`, it promises that `x == x + _setindex_zero(x, dy, inds...)` for all inputs; i.e. it always returns an additive identity. | |
""" | |
function _setindex_zero(x::AbstractArray, dy, inds...) end |
I wonder if we shouldn't call it _setable_zero
or something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe, I don't like any names. I guess it takes indices in a way that's like set/getindex functions.
return Base.unsafe_getindex(x, i), getindex(ẋ, i) | ||
end | ||
|
||
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't need any mode does it?
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) | |
function rrule(cfg::RuleConfig, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer) |
Not sure why we aren't just calling the rrule
for get infact?
Or putting a Union{typeof(getindex), typeof(Base.getindex))
in the function arg on them.
(might even be able to stick view
in that union too?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess I thought the first Zygote use of this might leave its own rule intact for scalar indexing, to make OneElement for that, and use CR for all others. In which case rrule_via_ad
here will call that.
More generally it may want to do other more efficient things for indexing ranges. Or we may want to do that here, and remove this entirely.
test/rulesets/Base/indexing.jl
Outdated
@test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1]) | ||
|
||
y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU | ||
@test_skip Array(y3) == Array(x_gpu)[1, [1,1,2]] # error in Pkg.test, no idea why |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we reproduce and open an issue on Julia itself and link back here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After fiddling a bit, here's a better version. These steps run in global scope, but fail inside the let block. FiniteDifferences is involved:
julia> let
x_23_gpu = jl(rand(2, 3)) # using JLArrays, loaded for @gpu in test_helpers.jl
# Scalar indexing, copied from: @macroexpand @allowscalar A[i]
y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed)
@test y1 == @allowscalar x_23_gpu[1]
bk1(1.0)
end
ERROR: StackOverflowError:
Stacktrace:
[1] to_vec(x::JLArray{Float64, 2})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/to_vec.jl:73
[2] to_vec(x::Base.ReshapedArray{Float64, 1, JLArray{Float64, 2}, Tuple{}})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/to_vec.jl:84
--- the last 2 lines are repeated 39990 more times ---
[79983] to_vec(x::JLArray{Float64, 2})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/to_vec.jl:73
At best FiniteDifferences can give zero here. But the parameter which needs tracking is embedded in the function () -> x_23_gpu[1]
and I don't think it can unpack that.
This wants to add a rule for
A[i,j,k]
any AbstractArray.The earlier rule was only for
Array
. I think the argument for that is that, ultimately, indexing of any linear algebra wrapper resolves to indexing the underlying array. But it resolves to scalar indexing, which I think will be quite inefficient for something likegradient(x -> sum(x[:,1]), transpose(rand(3,3)))[1]
. And in practice that fails (with just theArray
rule) as it creates & mutates an array to hold the parts.The internal function
_zerolike_writeat
which was previously used for some other rules is re-named∇getindex
and simplified:I am not sure why it needed. It has rules to allow higher derivatives.dims
It also always makes a full dense array; we could consider adding something like
Zygote.OneElement
to be more efficient at scalar indexing. But once you add two of those you get an Array; perhaps InplaceableThunk is eventually going to be better?