From f49b118c00cbf8653616d5f2d4803a0a851a9377 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 3 Aug 2022 00:59:53 -0400 Subject: [PATCH] handle indexing of GPU arrays --- Project.toml | 5 +++-- src/ChainRules.jl | 3 ++- src/rulesets/Base/base.jl | 20 ++++++++++++++++++++ src/rulesets/Base/indexing.jl | 21 +++++++++++++++++++-- test/rulesets/Base/indexing.jl | 20 ++++++++++++++++++++ test/runtests.jl | 3 +++ 6 files changed, 67 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 64b86cc45..b66176754 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" version = "1.44.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -16,6 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] +Adapt = "3.4.0" ChainRulesCore = "1.15.3" ChainRulesTestUtils = "1.5" Compat = "3.42.0, 4" @@ -30,7 +32,6 @@ StructArrays = "0.6.11" julia = "1.6" [extras] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -40,4 +41,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index b314d7be7..30e492d2e 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -1,10 +1,11 @@ module ChainRules +using Adapt: adapt using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using ChainRulesCore using Compat using Distributed -using GPUArraysCore: AbstractGPUArrayStyle +using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle using IrrationalConstants: logtwo, logten using LinearAlgebra using LinearAlgebra.BLAS diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index c10ba6e71..d486fd67e 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -243,3 +243,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...) return y, map_pullback end + +##### +##### `task_local_storage` +##### + +# Called by `@allowscalar` from GPUArrays + +ChainRules.@non_differentiable task_local_storage(key::Any) +ChainRules.@non_differentiable task_local_storage(key::Any, value::Any) + +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value) + y, back = task_local_storage(key, value) do + rrule_via_ad(config, body) + end + function task_local_storage_pullback(dy) + dbody = only(back(dy)) + return (NoTangent(), dbody, NoTangent(), NoTangent()) + end + return y, task_local_storage_pullback +end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index db25537b2..14bdf747d 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -113,8 +113,6 @@ function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...) end function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds...) view(dx, inds...) .+= dy - # For GPU arrays, `inds::Union{Integer, Base.Slice}...` is fine, but any other AbstractArray risks overwriting. - # Those should call `NNlib.scatter!`, alla https://github.com/FluxML/Zygote.jl/pull/1131 return dx end @@ -134,6 +132,25 @@ function rrule(::typeof(∇getindex), x, dy, inds...) return z, ∇getindex_pullback end +# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers. +# To avoid this, copy everything back to the CPU. +# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice: + +function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Integer...) + view(dx, inds...) .+= Ref(dy) + return dx +end +function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...) + view(dx, inds...) .+= dy + return dx +end +function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds...) + dx_cpu = adapt(Array, dx) + view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy) + copyto!(dx, dx_cpu) + return dx +end + ##### ##### first, tail ##### diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 0d5b09398..c2b9203a4 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -143,6 +143,25 @@ test_rrule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3; check_inferred=false) test_rrule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2; check_inferred=false) end + + @testset "GPU" begin + x_23_gpu = jl(rand(2, 3)) + + # Scalar indexing, copied from: @macroexpand @allowscalar A[i] + # Gives an error in Pkg.test, no idea why + # y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed) + # @test y1 == @allowscalar x_gpu[1] + # bk1(1.0) # This is zero, because finite-differencing ignores the function + # ... but this works, and calls the rule: + # Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1] + + y2, bk2 = rrule(getindex, x_23_gpu, :, 2:3) # fast path, just broadcast .+= + @test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1]) + + y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU + @test_skip Array(y3) == Array(x_gpu)[1, [1,1,2]] # error in Pkg.test, no idea why + @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) + end end @testset "first & tail" begin @@ -178,6 +197,7 @@ end end @testset "unsafe_getindex" begin + # In real life this is called only on some AbstractRanges, but easier to test on Array: test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3) test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3) end diff --git a/test/runtests.jl b/test/runtests.jl index 9ac5c5981..71444f388 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,12 +2,15 @@ using Test, ChainRulesCore, ChainRulesTestUtils @nospecialize +using Adapt using Base.Broadcast: broadcastable using ChainRules using ChainRulesCore using ChainRulesTestUtils using ChainRulesTestUtils: rand_tangent, _fdm using FiniteDifferences +using GPUArraysCore +using JLArrays using LinearAlgebra using LinearAlgebra.BLAS using LinearAlgebra: dot