Use KernelAbstractions.jl for gather/scatter kernels (#487)
* Use KA for gather

* Finish scatter

* Update testsuite for gather/scatter

* Cleanup

* Update tests

* Retain NNlibCUDA scatter kernels

* Fixup

* Add @inbounds

* Add compat

* Use KA unsafe free
pxl-th authored Apr 16, 2023
1 parent ee909e6 commit 1c6a87c
Showing 12 changed files with 536 additions and 469 deletions.
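The bullet points above summarize the port: hand-written CUDA.jl gather/scatter kernels are replaced by backend-agnostic KernelAbstractions.jl (KA) kernels, with Atomix supplying atomic updates. As a rough, hypothetical sketch of the KA style (illustrative names only, not the code added by this commit), a gather kernel can be written and launched like this:

using KernelAbstractions

# One work-item per output element: copy src[idx[i]] into dst[i].
@kernel function gather_kernel!(dst, @Const(src), @Const(idx))
    i = @index(Global)
    @inbounds dst[i] = src[idx[i]]
end

# Launch on whichever backend (CPU, CUDA, AMDGPU, ...) owns dst.
function gather_sketch!(dst, src, idx)
    backend = KernelAbstractions.get_backend(dst)
    gather_kernel!(backend)(dst, src, idx; ndrange = length(dst))
    KernelAbstractions.synchronize(backend)
    return dst
end

Because the same kernel source runs on any KA backend once the matching GPU package is loaded, the CUDA-only ext/NNlibCUDA/src/gather.jl below can simply be deleted.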
6 changes: 5 additions & 1 deletion Project.toml
@@ -4,7 +4,9 @@ version = "0.8.19"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -21,8 +23,10 @@ NNlibAMDGPUExt = "AMDGPU"
[compat]
AMDGPU = "0.4.8"
Adapt = "2, 3.2"
Atomix = "0.1"
ChainRulesCore = "1.13"
KernelAbstractions = "0.9"
GPUArraysCore = "0.1"
KernelAbstractions = "0.9.2"
Requires = "0.5, 1.0"
julia = "1.6"

1 change: 0 additions & 1 deletion ext/NNlibCUDA/src/NNlibCUDA.jl
@@ -13,7 +13,6 @@ include("batchedmul.jl")
include("ctc.jl")
include("fold.jl")
include("scatter.jl")
include("gather.jl")
include("utils.jl")
include("cudnn/cudnn.jl")
include("cudnn/conv.jl")
65 changes: 0 additions & 65 deletions ext/NNlibCUDA/src/gather.jl

This file was deleted.

12 changes: 6 additions & 6 deletions ext/NNlibCUDA/src/scatter.jl
@@ -30,7 +30,7 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size
return nothing
end

-function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
+function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
max_idx, max_dims_idx, dims_size) where OP
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@@ -73,7 +73,7 @@ end

## Gradients

-function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
+function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@@ -93,7 +93,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
return nothing
end

-function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
+function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@@ -113,7 +113,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca
return nothing
end

-function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
+function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@@ -160,13 +160,13 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca
end

function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
-                            src::AnyCuArray{Tsrc,Nsrc},
+                            src::AnyCuArray{Tsrc,Nsrc},
idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
rev_idx = NNlib.reverse_indices(idx)
rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))

if dims == 0
max_idx = length(idx)
args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc
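For orientation, the user-facing entry points these gradient kernels back are NNlib.scatter and NNlib.∇scatter_src. A small forward-pass usage example (the result shown is what scatter with + should produce under its default destination-size inference; it is not taken from this diff):

using NNlib

src = ones(Float32, 5)
idx = [1, 2, 2, 3, 3]
NNlib.scatter(+, src, idx)   # expected: Float32[1.0, 2.0, 2.0]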
22 changes: 11 additions & 11 deletions ext/NNlibCUDA/test/gather.jl
@@ -1,7 +1,7 @@
@testset "gather" begin
@testset "gather" begin
T = Float32
CT = CuArray{Float32}

## 1d src, 2d index of ints -> 2d output
src = CT([3, 4, 5, 6, 7])
index = cu([1 2 3 4;
@@ -10,14 +10,14 @@
output = CT([3 4 5 6;
6 4 3 5;
5 7 7 5])

y = NNlib.gather(src, index)
@test y isa CuArray{Float32,2}
@test size(y) == size(index)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)

## 1d src, 2d index of tuples -> 2d output
src = CT([3, 4, 5, 6, 7])
index = cu([(1,) (2,) (3,) (4,);
@@ -26,14 +26,14 @@
output = CT([3 4 5 6;
6 4 3 5;
5 7 7 5])

y = NNlib.gather(src, index)
@test y isa CuArray{Float32,2}
@test size(y) == size(index)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)

## 1d src, 2d index of CartesianIndex -> 2d output
src = CT([3, 4, 5, 6, 7])
index = cu(CartesianIndex.([(1,) (2,) (3,) (4,);
@@ -42,7 +42,7 @@
output = CT([3 4 5 6;
6 4 3 5;
5 7 7 5])

y = NNlib.gather(src, index)
@test y isa CuArray{Float32,2}
@test size(y) == size(index)
@@ -66,7 +66,7 @@


## 2d src, 2d index of ints -> 3d output
-src = CT([3 5 7
+src = CT([3 5 7
4 6 8])
index = cu([1 2 3;
2 2 1;
@@ -79,14 +79,14 @@

output[:,:,2] = [5 5 3
6 6 4]

output[:,:,3] = [7 3 7
8 4 8]

y = NNlib.gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa CuArray{Float32,3}
-@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
+@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
end
2 changes: 2 additions & 0 deletions src/NNlib.jl
@@ -1,10 +1,12 @@
module NNlib

import Atomix
import ChainRulesCore: rrule

using Base.Broadcast: broadcasted
using Base.Threads
using ChainRulesCore
using GPUArraysCore
using KernelAbstractions
using KernelAbstractions: @atomic
using LinearAlgebra
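The new Atomix dependency and the KernelAbstractions: @atomic import above are what a portable scatter needs: unlike gather, several source elements may target the same destination slot, so the accumulation has to be atomic. A minimal, hypothetical sketch (illustrative names, not the kernels added by this commit):

using KernelAbstractions
using KernelAbstractions: @atomic

# One work-item per source element; colliding writes to dst[idx[i]]
# are made safe by the atomic read-modify-write.
@kernel function scatter_add_kernel!(dst, @Const(src), @Const(idx))
    i = @index(Global)
    @atomic dst[idx[i]] += src[i]
end

function scatter_add_sketch!(dst, src, idx)
    backend = KernelAbstractions.get_backend(dst)
    scatter_add_kernel!(backend)(dst, src, idx; ndrange = length(src))
    KernelAbstractions.synchronize(backend)
    return dst
end

The hand-written CUDA scatter kernels in ext/NNlibCUDA/src/scatter.jl are retained (per the commit message), apparently because the gradient code there still depends on CUDA-specific reverse-index handling (rev_idx with CUDA.cudaconvert).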
(Diffs for the remaining changed files are not shown.)