Skip to content

Commit

Permalink
handle indexing of GPU arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 10, 2022
1 parent 9e1aa8c commit f49b118
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 5 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.44.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -16,6 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
Adapt = "3.4.0"
ChainRulesCore = "1.15.3"
ChainRulesTestUtils = "1.5"
Compat = "3.42.0, 4"
Expand All @@ -30,7 +32,6 @@ StructArrays = "0.6.11"
julia = "1.6"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand All @@ -40,4 +41,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
3 changes: 2 additions & 1 deletion src/ChainRules.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module ChainRules

using Adapt: adapt
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using ChainRulesCore
using Compat
using Distributed
using GPUArraysCore: AbstractGPUArrayStyle
using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle
using IrrationalConstants: logtwo, logten
using LinearAlgebra
using LinearAlgebra.BLAS
Expand Down
20 changes: 20 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu
map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
return y, map_pullback
end

#####
##### `task_local_storage`
#####

# Called by `@allowscalar` from GPUArrays

ChainRules.@non_differentiable task_local_storage(key::Any)
ChainRules.@non_differentiable task_local_storage(key::Any, value::Any)

function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value)
y, back = task_local_storage(key, value) do
rrule_via_ad(config, body)
end
function task_local_storage_pullback(dy)
dbody = only(back(dy))
return (NoTangent(), dbody, NoTangent(), NoTangent())
end
return y, task_local_storage_pullback
end
21 changes: 19 additions & 2 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...)
end
function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds...)
view(dx, inds...) .+= dy
# For GPU arrays, `inds::Union{Integer, Base.Slice}...` is fine, but any other AbstractArray risks overwriting.
# Those should call `NNlib.scatter!`, alla https://github.com/FluxML/Zygote.jl/pull/1131
return dx
end

Expand All @@ -134,6 +132,25 @@ function rrule(::typeof(∇getindex), x, dy, inds...)
return z, ∇getindex_pullback
end

# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers.
# To avoid this, copy everything back to the CPU.
# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice:

function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Integer...)
view(dx, inds...) .+= Ref(dy)
return dx
end
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...)
view(dx, inds...) .+= dy
return dx
end
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds...)
dx_cpu = adapt(Array, dx)
view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy)
copyto!(dx, dx_cpu)
return dx
end

#####
##### first, tail
#####
Expand Down
20 changes: 20 additions & 0 deletions test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,25 @@
test_rrule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3; check_inferred=false)
test_rrule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2; check_inferred=false)
end

@testset "GPU" begin
x_23_gpu = jl(rand(2, 3))

# Scalar indexing, copied from: @macroexpand @allowscalar A[i]
# Gives an error in Pkg.test, no idea why
# y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed)
# @test y1 == @allowscalar x_gpu[1]
# bk1(1.0) # This is zero, because finite-differencing ignores the function
# ... but this works, and calls the rule:
# Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1]

y2, bk2 = rrule(getindex, x_23_gpu, :, 2:3) # fast path, just broadcast .+=
@test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1])

y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU
@test_skip Array(y3) == Array(x_gpu)[1, [1,1,2]] # error in Pkg.test, no idea why
@test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0])
end
end

@testset "first & tail" begin
Expand Down Expand Up @@ -178,6 +197,7 @@ end
end

@testset "unsafe_getindex" begin
# In real life this is called only on some AbstractRanges, but easier to test on Array:
test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3)
test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3)
end
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ using Test, ChainRulesCore, ChainRulesTestUtils

@nospecialize

using Adapt
using Base.Broadcast: broadcastable
using ChainRules
using ChainRulesCore
using ChainRulesTestUtils
using ChainRulesTestUtils: rand_tangent, _fdm
using FiniteDifferences
using GPUArraysCore
using JLArrays
using LinearAlgebra
using LinearAlgebra.BLAS
using LinearAlgebra: dot
Expand Down

0 comments on commit f49b118

Please sign in to comment.