Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Indexing #655

Merged
merged 11 commits into from
Aug 12, 2022
Merged

Indexing #655

merged 11 commits into from
Aug 12, 2022

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Aug 2, 2022

This wants to add a rule for A[i,j,k] any AbstractArray.

The earlier rule was only for Array. I think the argument for that is that, ultimately, indexing of any linear algebra wrapper resolves to indexing the underlying array. But it resolves to scalar indexing, which I think will be quite inefficient for something like gradient(x -> sum(x[:,1]), transpose(rand(3,3)))[1]. And in practice that fails (with just the Array rule) as it creates & mutates an array to hold the parts.

The internal function _zerolike_writeat which was previously used for some other rules is re-named ∇getindex and simplified: I am not sure why it needed dims . It has rules to allow higher derivatives.

It also always makes a full dense array; we could consider adding something like Zygote.OneElement to be more efficient at scalar indexing. But once you add two of those you get an Array; perhaps InplaceableThunk is eventually going to be better?

@mcabbott mcabbott marked this pull request as draft August 2, 2022 14:44
src/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
Comment on lines +118 to +128
function rrule(::typeof(∇getindex), x, dy, inds...)
z = ∇getindex(x, dy, inds...)
function ∇getindex_pullback(dz)
d2y = getindex(unthunk(dz), inds...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This second derivative function doesn't seem to infer well, can it be improved?

src/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
src/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool cool once these comments are addressed as you feel best
merge when happy

src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
src/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
src/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
∇getindex!(dx, x, dy, plain_inds...)
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
_setindex_zero(x::AbstractArray, dy, inds...)
Basically this function is like `zero(x)` except that it ensure that
it is possible to set the value at index `inds` to `dy`.
It does this while preserving at least the outermost the structure of `x`.
Like `zero(x)`, it promises that `x == x + _setindex_zero(x, dy, inds...)` for all inputs; i.e. it always returns an additive identity.
"""
function _setindex_zero(x::AbstractArray, dy, inds...) end

I wonder if we shouldn't call it _setable_zero or something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, I don't like any names. I guess it takes indices in a way that's like set/getindex functions.

src/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
src/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
src/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
return Base.unsafe_getindex(x, i), getindex(ẋ, i)
end

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need any mode does it?

Suggested change
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
function rrule(cfg::RuleConfig, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)

Not sure why we aren't just calling the rrule for get infact?
Or putting a Union{typeof(getindex), typeof(Base.getindex)) in the function arg on them.
(might even be able to stick view in that union too?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I thought the first Zygote use of this might leave its own rule intact for scalar indexing, to make OneElement for that, and use CR for all others. In which case rrule_via_ad here will call that.

More generally it may want to do other more efficient things for indexing ranges. Or we may want to do that here, and remove this entirely.

test/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
@test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1])

y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU
@test_skip Array(y3) == Array(x_gpu)[1, [1,1,2]] # error in Pkg.test, no idea why
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we reproduce and open an issue on Julia itself and link back here?

Copy link
Member Author

@mcabbott mcabbott Aug 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After fiddling a bit, here's a better version. These steps run in global scope, but fail inside the let block. FiniteDifferences is involved:

julia> let
               x_23_gpu = jl(rand(2, 3))  # using JLArrays, loaded for @gpu in test_helpers.jl
           
               # Scalar indexing, copied from:  @macroexpand @allowscalar A[i]
               y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed)
                @test y1 == @allowscalar x_23_gpu[1]
                bk1(1.0)
       end
ERROR: StackOverflowError:
Stacktrace:
     [1] to_vec(x::JLArray{Float64, 2})
       @ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/to_vec.jl:73
     [2] to_vec(x::Base.ReshapedArray{Float64, 1, JLArray{Float64, 2}, Tuple{}})
       @ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/to_vec.jl:84
--- the last 2 lines are repeated 39990 more times ---
 [79983] to_vec(x::JLArray{Float64, 2})
       @ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/to_vec.jl:73

At best FiniteDifferences can give zero here. But the parameter which needs tracking is embedded in the function () -> x_23_gpu[1] and I don't think it can unpack that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants