Skip to content

Commit

Permalink
Add extension
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Jun 21, 2024
1 parent bb07736 commit e7bc827
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 69 deletions.
11 changes: 11 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,28 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[weakdeps]
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"

[extensions]
ChainRulesKernelAbstractionsExt = ["Atomix", "GPUArrays", "KernelAbstractions"]

[compat]
Adapt = "3.4.0, 4"
Atomix = "0.1"
ChainRulesCore = "1.20"
ChainRulesTestUtils = "1.5"
Compat = "3.46, 4.2"
Distributed = "1"
FiniteDifferences = "0.12.20"
GPUArraysCore = "0.1.0"
GPUArrays = "10"
IrrationalConstants = "0.1.1, 0.2"
JLArrays = "0.1"
JuliaInterpreter = "0.8,0.9"
KernelAbstractions = "0.9"
LinearAlgebra = "1"
Random = "1"
RealDot = "0.1"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
module ChainRulesKernelAbstractionsExt

import Adapt
import Atomix
import ChainRules
import GPUArrays
import KernelAbstractions as KA

using GPUArraysCore: AbstractGPUArray
using KernelAbstractions

# Accumulate `dy` into `dx` at `inds` for GPU arrays with arbitrary indices.
#
# The active implementation moves `dx`, `inds` and `dy` to host `Array`s with
# `Adapt.adapt`, performs the broadcasted `+=` on a host view, and copies the
# result back into `dx`. Returns `dx`.
#
# TODO: enable the device-side path sketched below once it is ready; it would
# launch the `scatter!` kernel when the backend supports atomics and keep the
# host round trip only as the fallback:
#
#     kab = get_backend(dx)
#     if KA.supports_atomics(kab)
#         gids = GPUArrays.to_indices(dx, inds)
#         idims = map(length, gids)
#         Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids)
#         scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy))
#     else
#         <host round trip>
#     end
#
# NOTE(review): until that path is enabled, nothing in this method uses
# KernelAbstractions, Atomix or GPUArrays.
function ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...)
    host_dx = Adapt.adapt(Array, dx)
    host_inds = Adapt.adapt(Array, inds)
    host_slice = view(host_dx, host_inds...)
    host_slice .+= Adapt.adapt(Array, dy)
    # Write the accumulated values back into the original (device) array.
    copyto!(dx, host_dx)
    return dx
end

# KernelAbstractions kernel: each work-item takes its global linear index and
# hands it, together with all arguments, to `_scatter!`.
# NOTE(review): the only (currently commented-out) launch site uses
# `ndrange=length(dy)`, i.e. presumably one work-item per element of `src` — confirm.
@kernel function scatter!(op, dest, src, idims, Is::Vararg{Any, N}) where N
    i = @index(Global, Linear)
    _scatter!(i, op, dest, src, idims, Is...)
end

# Per-element body of the scatter kernel, generated for the number `N` of
# index collections in `Is`.
#
# - `i`     : linear position into `src` and into `CartesianIndices(idims)`.
# - `idims` : extents used to convert `i` into a Cartesian index.
# - `Is`    : `N` index collections; the j-th destination index is
#             `Is[j][cidx[j]]`, where `cidx` is the Cartesian form of `i`.
#
# The value `src[i]` is combined into `dest` at the gathered indices through
# `_accum!(op, dest, value, indices...)`.
@generated function _scatter!(i, op, dest, src, idims, Is::Vararg{Any, N}) where N
    return quote
        cidx = @inbounds CartesianIndices(idims)[i]
        # Gather the N destination indices first (unrolled at compile time) ...
        targets = Base.Cartesian.@ntuple $N j -> @inbounds((Is[j])[cidx[j]])
        # ... then read the source value and accumulate it.
        _accum!(op, dest, src[i], targets...)
    end
end

# Combine `val` into `dest[ids...]` with `op` through `Atomix.modify!` acting
# on an `Atomix.IndexableRef`, and return whatever `Atomix.modify!` returns.
function _accum!(op, dest, val, ids...)
    # `ids` is already the tuple of indices expected by `IndexableRef`.
    ref = Atomix.IndexableRef(dest, ids)
    return Atomix.modify!(ref, op, val)
end

end
17 changes: 3 additions & 14 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ function rrule(::typeof(∇getindex), x, dy, inds...)
return z, ∇getindex_pullback
end

# Indexing with repeated indices on a GPU causes race conditions in ∇getindex, giving wrong answers.
# To avoid this, copy everything back to the CPU.
# But don't do that for indices which are known to be unique; e.g. in `A[1, 2:3, :]` the colon becomes a `Base.Slice`:
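# NOTE:
# Generic `∇getindex!(dx::AbstractGPUArray, dy, inds...)`
# is implemented in `ext/` (ChainRulesKernelAbstractionsExt). It currently copies to the CPU
# and back; the atomic scatter kernel defined there is not enabled yet.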

function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...)
view(dx, inds...) .+= Ref(dy)
Expand All @@ -181,17 +181,6 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUni
return dx
end

# Generic GPU fallback: accumulate `dy` into `dx` at `inds` by way of the host.
# `dx`, `inds` and `dy` are adapted to `Array`, the broadcasted `+=` is applied
# to a host view, and the result is copied back into `dx`, which is returned.
function ∇getindex!(dx::AbstractGPUArray, dy, inds...)
    # TODO: the desired implementation is an atomic in-place update,
    #     @atomic dx[inds...] .+= dy
    #     return dx
    # which would make the host round trip below unnecessary.
    host_dx = adapt(Array, dx)
    host_inds = adapt(Array, inds)
    host_slice = view(host_dx, host_inds...)
    host_slice .+= adapt(Array, dy)
    copyto!(dx, host_dx)
    return dx
end

#####
##### view
#####
Expand Down
55 changes: 0 additions & 55 deletions t.jl

This file was deleted.

0 comments on commit e7bc827

Please sign in to comment.