
Use KernelAbstractions.jl for gather/scatter kernels #487

Merged (10 commits) on Apr 16, 2023
Project.toml (6 changes: 5 additions & 1 deletion)
@@ -4,7 +4,9 @@ version = "0.8.19"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -21,8 +23,10 @@ NNlibAMDGPUExt = "AMDGPU"
[compat]
AMDGPU = "0.4.8"
Adapt = "2, 3.2"
Atomix = "0.1"
ChainRulesCore = "1.13"
KernelAbstractions = "0.9"
GPUArraysCore = "0.1"
KernelAbstractions = "0.9.2"
Requires = "0.5, 1.0"
julia = "1.6"

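The new dependencies preview the core change: gather/scatter kernels written once against KernelAbstractions.jl and compiled for whichever backend the arrays live on, with Atomix supplying portable atomic updates. A minimal sketch of that style (my own illustration, not the PR's actual kernel):

using KernelAbstractions, Atomix

# Backend-agnostic scatter-add: several source elements may map to the
# same destination slot, so the update must be atomic.
@kernel function scatter_add_kernel!(dst, @Const(src), @Const(idx))
    i = @index(Global)
    Atomix.@atomic dst[idx[i]] += src[i]
end

function scatter_add!(dst, src, idx)
    backend = KernelAbstractions.get_backend(dst)
    scatter_add_kernel!(backend)(dst, src, idx; ndrange = length(idx))
    KernelAbstractions.synchronize(backend)
    return dst
end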
ext/NNlibCUDA/src/NNlibCUDA.jl (1 change: 0 additions & 1 deletion)
@@ -13,7 +13,6 @@ include("batchedmul.jl")
include("ctc.jl")
include("fold.jl")
include("scatter.jl")
include("gather.jl")
include("utils.jl")
include("cudnn/cudnn.jl")
include("cudnn/conv.jl")
ext/NNlibCUDA/src/gather.jl (65 changes: 0 additions & 65 deletions)

This file was deleted.
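The CUDA-only gather kernel is deleted outright; per the PR title, a single KernelAbstractions kernel inside NNlib now serves every backend. A minimal sketch of what such a kernel can look like (illustrative, not the code NNlib adopted):

using KernelAbstractions

# One work-item per output element. Unlike scatter, gather writes each
# destination slot exactly once, so no atomics are needed.
@kernel function gather_kernel!(dst, @Const(src), @Const(idx))
    i = @index(Global)
    @inbounds dst[i] = src[idx[i]]
end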

ext/NNlibCUDA/src/scatter.jl (12 changes: 6 additions & 6 deletions; the changed lines differ only in trailing whitespace)
@@ -30,7 +30,7 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size
return nothing
end

function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
max_idx, max_dims_idx, dims_size) where OP
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@@ -73,7 +73,7 @@ end

## Gradients

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@@ -93,7 +93,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
return nothing
end

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@@ -113,7 +113,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca
return nothing
end

function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@@ -160,13 +160,13 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca
end

function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
src::AnyCuArray{Tsrc,Nsrc},
idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
rev_idx = NNlib.reverse_indices(idx)
rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))

if dims == 0
max_idx = length(idx)
args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc
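The gradient path that stays CUDA-specific leans on NNlib.reverse_indices, which inverts the index map: for every destination slot it collects the source positions that scatter into it, so the gradient kernel can accumulate without atomics. A small CPU illustration (the element layout of the result is my assumption, inferred from the per-element cudaconvert above):

using NNlib

idx = [1, 2, 1, 3]
rev = NNlib.reverse_indices(idx)
# Assumed layout: rev[k] collects the positions in idx whose value is k,
# so rev[1] should name the 1st and 3rd entries.
@show rev[1]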
ext/NNlibCUDA/test/gather.jl (22 changes: 11 additions & 11 deletions, largely trailing-whitespace fixes)
@@ -1,7 +1,7 @@
@testset "gather" begin
T = Float32
CT = CuArray{Float32}

## 1d src, 2d index of ints -> 2d output
src = CT([3, 4, 5, 6, 7])
index = cu([1 2 3 4;
@@ -10,14 +10,14 @@
output = CT([3 4 5 6;
6 4 3 5;
5 7 7 5])

y = NNlib.gather(src, index)
@test y isa CuArray{Float32,2}
@test size(y) == size(index)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)

## 1d src, 2d index of tuples -> 2d output
src = CT([3, 4, 5, 6, 7])
index = cu([(1,) (2,) (3,) (4,);
@@ -26,14 +26,14 @@
output = CT([3 4 5 6;
6 4 3 5;
5 7 7 5])

y = NNlib.gather(src, index)
@test y isa CuArray{Float32,2}
@test size(y) == size(index)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)

## 1d src, 2d index of CartesianIndex -> 2d output
src = CT([3, 4, 5, 6, 7])
index = cu(CartesianIndex.([(1,) (2,) (3,) (4,);
@@ -42,7 +42,7 @@
output = CT([3 4 5 6;
6 4 3 5;
5 7 7 5])

y = NNlib.gather(src, index)
@test y isa CuArray{Float32,2}
@test size(y) == size(index)
@@ -66,7 +66,7 @@


## 2d src, 2d index of ints -> 3d output
src = CT([3 5 7
4 6 8])
index = cu([1 2 3;
2 2 1;
@@ -79,14 +79,14 @@

output[:,:,2] = [5 5 3
6 6 4]

output[:,:,3] = [7 3 7
8 4 8]

y = NNlib.gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa CuArray{Float32,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
end
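The last assertion is the general shape rule for gather: with M = NNlib.typelength(eltype(index)) index components consumed per lookup, the output keeps the leading Nsrc - M source dimensions and appends the index shape. The same rule is easy to confirm on the CPU:

using NNlib

src = Float32[3 5 7;
              4 6 8]            # 2×3
index = [1 2 3;
         2 2 1;
         3 1 3]                 # 3×3 of column indices
y = NNlib.gather(src, index)    # each integer index selects a column
@assert size(y) == (2, 3, 3)    # (size(src, 1), size(index)...)
@assert y[:, 1, 1] == src[:, 1] # index[1, 1] == 1 picks column 1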
src/NNlib.jl (2 changes: 2 additions & 0 deletions)
@@ -1,10 +1,12 @@
module NNlib

import Atomix
import ChainRulesCore: rrule

using Base.Broadcast: broadcasted
using Base.Threads
using ChainRulesCore
using GPUArraysCore
using KernelAbstractions
using KernelAbstractions: @atomic
using LinearAlgebra
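The imports mirror the Project.toml changes: Atomix and KernelAbstractions' @atomic for the scatter kernels, and GPUArraysCore so NNlib can target all GPU arrays without depending on CUDA.jl or AMDGPU.jl directly. A hypothetical dispatch sketch of what that buys (method names are mine, not NNlib's):

using GPUArraysCore: AbstractGPUArray

# AbstractGPUArray is the common supertype of CuArray, ROCArray, etc.,
# so a single method covers every GPU backend.
kernel_path(::AbstractGPUArray) = :kernelabstractions_gpu
kernel_path(::AbstractArray)    = :threaded_cpu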