From 6888f795c4d4f31d15fff41fa29f1e604a30df47 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sat, 8 Apr 2023 22:53:18 +0300 Subject: [PATCH 01/10] Use KA for gather --- Project.toml | 1 + ext/NNlibCUDA/src/NNlibCUDA.jl | 1 - ext/NNlibCUDA/src/gather.jl | 65 ------- ext/NNlibCUDA/src/scatter.jl | 3 +- src/NNlib.jl | 1 + src/gather.jl | 129 +++++++------ src/scatter.jl | 107 +++++++---- test/gather.jl | 310 ++++++++++++++++--------------- test/runtests.jl | 226 +++++++++++------------ test/scatter.jl | 321 ++++++++++++++++++--------------- 10 files changed, 596 insertions(+), 568 deletions(-) delete mode 100644 ext/NNlibCUDA/src/gather.jl diff --git a/Project.toml b/Project.toml index e8c55345e..5bfef021f 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.8.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/ext/NNlibCUDA/src/NNlibCUDA.jl b/ext/NNlibCUDA/src/NNlibCUDA.jl index b0d5ea9e7..e1c3c225c 100644 --- a/ext/NNlibCUDA/src/NNlibCUDA.jl +++ b/ext/NNlibCUDA/src/NNlibCUDA.jl @@ -13,7 +13,6 @@ include("batchedmul.jl") include("ctc.jl") include("fold.jl") include("scatter.jl") -include("gather.jl") include("utils.jl") include("cudnn/cudnn.jl") include("cudnn/conv.jl") diff --git a/ext/NNlibCUDA/src/gather.jl b/ext/NNlibCUDA/src/gather.jl deleted file mode 100644 index 5690dfc59..000000000 --- a/ext/NNlibCUDA/src/gather.jl +++ /dev/null @@ -1,65 +0,0 @@ -function gather_check_dims(X::AbstractArray{Tx,Nx}, - Y::AbstractArray{Ty,Ny}, - idx::AbstractArray{Tidx,Nidx}) where - {Tx,Ty,Tidx<:IntOrIntTuple,Nx,Ny,Nidx} - M = NNlib.typelength(Tidx) - dims = gather_check_dims(Nx, Ny, M, Nidx) - size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes.")) - size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) - return dims -end - -function gather_check_dims(X::AbstractArray{Tx,Nx}, - Y::AbstractArray{Ty,Ny}, - idx::AbstractArray{CartesianIndex{M},Nidx}) where - {Tx,Ty,Nx,Ny,M,Nidx} - dims = gather_check_dims(Nx, Ny, M, Nidx) - size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes.")) - size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) - return dims -end - -function gather_check_dims(Nx, Ny, M, Nidx) - @assert Nx - M == Ny - Nidx "Incompatible input shapes of (dst, src, idx) = ($Nx, $Ny, $Nidx)." - dims = Nx - M - dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims.")) - return dims -end - -function gather_kernel!(dst, src, idx, max_idx, max_dims_idx, dims_size) - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) - dims_i = CartesianIndices(dims_size)[k+1] - dst[index] = src[dims_i, idx[j+1]...] - end - return nothing -end - -function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size) - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) - dims_i = CartesianIndices(dims_size)[k+1] - li = Base._to_linear_index(src, Tuple(dims_i)..., Tuple(idx[j+1])...) 
- dst[index] = src[li] - end - return nothing -end - -function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) - dims = gather_check_dims(src, dst, idx) - dims_size = size(src)[1:dims] - max_dims_idx = prod(dims_size) - max_idx = max_dims_idx * length(idx) - args = dst, src, idx, max_idx, max_dims_idx, dims_size - - kernel = @cuda launch=false gather_kernel!(args...) - config = launch_configuration(kernel.fun; max_threads=256) - threads = min(max_idx, config.threads) - blocks = cld(max_idx, threads) - kernel(args...; threads=threads, blocks=blocks) - return dst -end diff --git a/ext/NNlibCUDA/src/scatter.jl b/ext/NNlibCUDA/src/scatter.jl index 6a9c5ba87..f839554f1 100644 --- a/ext/NNlibCUDA/src/scatter.jl +++ b/ext/NNlibCUDA/src/scatter.jl @@ -30,8 +30,7 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size return nothing end -function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - max_idx, max_dims_idx, dims_size) where OP +function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx diff --git a/src/NNlib.jl b/src/NNlib.jl index 184fbfc74..7dfbd8805 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -5,6 +5,7 @@ import ChainRulesCore: rrule using Base.Broadcast: broadcasted using Base.Threads using ChainRulesCore +using GPUArraysCore using KernelAbstractions using KernelAbstractions: @atomic using LinearAlgebra diff --git a/src/gather.jl b/src/gather.jl index 014d085be..a25b4fc31 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -1,41 +1,10 @@ -""" - NNlib.gather!(dst, src, idx) - -Reverse operation of [`scatter!`](@ref). Gathers data from source `src` -and writes it in destination `dst` according to the index array `idx`. -For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to - - dst[:, ... , k] .= src[:, ... , idx[k]...] - -Notice that if `idx` is a vector containing integers, -and both `dst` and `src` are matrices, previous expression simplifies to - - dst[:, k] .= src[:, idx[k]] - -and `k` will run over `1:length(idx)`. - -The elements of `idx` can be integers or integer tuples and may be repeated. -A single `src` column can end up being copied into zero, one, -or multiple `dst` columns. - -See [`gather`](@ref) for an allocating version. -""" -function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) - dims = scatter_dims(src, dst, idx) - colons = ntuple(i -> Colon(), dims) - for k in CartesianIndices(idx) - _view(dst, colons, k) .= _view(src, colons, idx[k]) - end - return dst -end - """ NNlib.gather(src, idx) -> dst -Reverse operation of [`scatter`](@ref). Gathers data from source `src` +Reverse operation of [`scatter`](@ref). Gathers data from source `src` and writes it in a destination `dst` according to the index array `idx`. -For each `k` in `CartesianIndices(idx)`, assign values to `dst` +For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to dst[:, ... , k] .= src[:, ... , idx[k]...] @@ -45,10 +14,10 @@ and `src` is a matrix, previous expression simplifies to dst[:, k] .= src[:, idx[k]] -and `k` will run over `1:length(idx)`. +and `k` will run over `1:length(idx)`. -The elements of `idx` can be integers or integer tuples and may be repeated. -A single `src` column can end up being copied into zero, one, +The elements of `idx` can be integers or integer tuples and may be repeated. 
+A single `src` column can end up being copied into zero, one, or multiple `dst` columns. See [`gather!`](@ref) for an in-place version. @@ -68,29 +37,19 @@ julia> NNlib.gather([1 2 3; 4 5 6], [1,3,1,3,1]) 4 6 4 6 4 ``` """ -function gather(src::AbstractArray{Tsrc, Nsrc}, - idx::AbstractArray{Tidx, Nidx}) where - {Tsrc, Nsrc, Nidx, Tidx} - - M = typelength(Tidx) +function gather( + src::AbstractArray{Tsrc, Nsrc}, idx::AbstractArray{Tidx, Nidx}, +) where {Tsrc, Nsrc, Nidx, Tidx} + M = typelength(Tidx) dstsize = (size(src)[1:Nsrc-M]..., size(idx)...) dst = similar(src, Tsrc, dstsize) return gather!(dst, src, idx) end -∇gather_src(Δ, src_size, idx) = scatter!(+, fill!(similar(Δ, eltype(Δ), src_size), 0), Δ, idx) - -function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) - y = gather!(dst, src, idx) - src_size = size(src) - gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent()) - return y, gather!_pullback -end - """ gather(src, IJK...) -Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and +Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and call `gather` on it: `gather(src, CartesianIndex.(IJK...))`. # Examples @@ -108,14 +67,72 @@ julia> NNlib.gather(src, [1, 2], [2, 4]) 11 ``` """ -function gather(src::AbstractArray{Tsrc, Nsrc}, - I::AbstractVector{<:Integer}, - J::AbstractVector{<:Integer}, - Ks::AbstractVector{<:Integer}...) where {Nsrc, Tsrc} - +function gather( + src::AbstractArray{Tsrc, Nsrc}, + I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, + Ks::AbstractVector{<:Integer}..., +) where {Nsrc, Tsrc} return gather(src, to_cartesian_index(I, J, Ks...)) end to_cartesian_index(IJK...) = CartesianIndex.(IJK...) @non_differentiable to_cartesian_index(::Any...) +""" + NNlib.gather!(dst, src, idx) + +Reverse operation of [`scatter!`](@ref). Gathers data from source `src` +and writes it in destination `dst` according to the index array `idx`. +For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to + + dst[:, ... , k] .= src[:, ... , idx[k]...] + +Notice that if `idx` is a vector containing integers, +and both `dst` and `src` are matrices, previous expression simplifies to + + dst[:, k] .= src[:, idx[k]] + +and `k` will run over `1:length(idx)`. + +The elements of `idx` can be integers or integer tuples and may be repeated. +A single `src` column can end up being copied into zero, one, +or multiple `dst` columns. + +See [`gather`](@ref) for an allocating version. +""" +function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) + dims = scatter_dims(src, dst, idx) + colons = ntuple(i -> Colon(), dims) + for k in CartesianIndices(idx) + _view(dst, colons, k) .= _view(src, colons, idx[k]) + end + return dst +end + +function gather!(dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) + n_dims = scatter_dims(src, dst, idx) + dims = size(src)[1:n_dims] + max_dims_idx = prod(dims) + ndrange = max_dims_idx * length(idx) + _gather!(KernelAbstractions.get_backend(src))( + dst, src, idx, CartesianIndices(dims), max_dims_idx; ndrange) + return dst +end + +@kernel function _gather!( + dst, @Const(src), @Const(idx), + dim_ids::CartesianIndices, max_dims_idx::Int, +) + i = @index(Global) + j, k = divrem(i - 1, max_dims_idx) + dst[i] = src[dim_ids[k + 1], Tuple(idx[j + 1])...] 
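+    # Each work-item copies one element: `divrem` above splits the flat index `i`
+    # into an entry of `idx` (`j + 1`) and an offset `k + 1` within the leading,
+    # non-indexed dimensions of `src`.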
+end
+
+∇gather_src(Δ, src_size, idx) = scatter!(+, fill!(similar(Δ, eltype(Δ), src_size), 0), Δ, idx)
+
+function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
+    y = gather!(dst, src, idx)
+    src_size = size(src)
+    gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent())
+    return y, gather!_pullback
+end
diff --git a/src/scatter.jl b/src/scatter.jl
index 88bb42c65..f0d17039e 100644
--- a/src/scatter.jl
+++ b/src/scatter.jl
@@ -14,14 +14,14 @@ typelength(::Type{<:NTuple{M}}) where M = M
 typelength(::Type{CartesianIndex{M}}) where M = M
 
 """
-Performs dimensional consistency checks and return the 
+Performs dimensional consistency checks and returns the
 dimensionality of the scattered objects.
 """
-function scatter_dims(X::AbstractArray{Tx,Nx},
-                      Y::AbstractArray{Ty,Ny},
-                      idx::AbstractArray{Tidx,Nidx}) where {Tx,Ty,Tidx,Nx,Ny,Nidx}
-    M = typelength(Tidx)
-    dims = scatter_dims(Nx, Ny, M, Nidx)
+function scatter_dims(
+    X::AbstractArray{Tx,Nx}, Y::AbstractArray{Ty,Ny},
+    idx::AbstractArray{Tidx,Nidx},
+) where {Tx,Ty,Tidx,Nx,Ny,Nidx}
+    dims = scatter_dims(Nx, Ny, typelength(Tidx), Nidx)
     size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
     size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
     return dims
@@ -41,7 +41,7 @@ _view(X, colons, k::Union{Integer, CartesianIndex}) = view(X, colons..., k)
     NNlib.scatter!(op, dst, src, idx)
 
 Scatter operation, which writes data in `src` into `dst` at locations `idx`.
-A binary reduction operator `op` is applied during the scatter. 
+A binary reduction operator `op` is applied during the scatter.
 For each index `k` in `idx`, accumulates values in `dst` according to
 
     dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])
@@ -53,7 +53,7 @@ See also [`scatter`](@ref), [`gather`](@ref).
- `op`: Operation to be applied to `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min` and `mean`.
- `dst`: The destination for `src` to aggregate to. This argument will be mutated.
- `src`: The source data for aggregating.
-- `idx`: The mapping for aggregation from source (index) to destination (value). 
+- `idx`: The mapping for aggregation from source (index) to destination (value).
    The `idx` array can contain either integers or tuples.
# Examples @@ -70,17 +70,49 @@ julia> NNlib.scatter!(*, fill(0.5, 2, 4), [1 10; 100 1000], [3,2]) 0.5 500.0 50.0 0.5 ``` """ -function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) where OP - dims = scatter_dims(dst, src, idx) - colons = Base.ntuple(_->Colon(), dims) - for k in CartesianIndices(idx) - dst_v = _view(dst, colons, idx[k]) - src_v = _view(src, colons, k) - dst_v .= (op).(dst_v, src_v) +# function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) where OP +# dims = scatter_dims(dst, src, idx) +# colons = Base.ntuple(_->Colon(), dims) +# for k in CartesianIndices(idx) +# dst_v = _view(dst, colons, idx[k]) +# src_v = _view(src, colons, k) +# dst_v .= (op).(dst_v, src_v) +# end +# dst +# end + +function scatter!(op::OP, dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) where OP + n_dims = scatter_dims(dst, src, idx) + args = if n_dims == 0 + ndrange = length(idx) + () + else + dims = size(dst)[1:n_dims] + max_dims_idx = prod(dims) + ndrange = max_dims_idx * length(idx) + (CartesianIndices(dims), max_dims_idx) end + _scatter!(KernelAbstractions.get_backend(dst))( + op, dst, src, idx, args...; ndrange) dst end +@kernel function _scatter!(op::OP, dst, src, idxs) where OP + i = @index(Global) + idx = Tuple(idxs[i]) + @atomic dst[idx...] = op(dst[idx...], src[i]) +end + +# @kernel function _scatter!( +# op::OP, dst, src, idxs, dim_ids::CartesianIndices, max_dims_idx::Int, +# ) where OP +# i = @index(Global) +# j, k = divrem(i - 1, max_dims_idx) +# dim_i = Tuple(dim_ids[k + 1]) +# idx = idxs[j + 1] +# @atomic dst[dim_i..., idx...] = op(dst[dim_i..., idx...], src[i]) +# end + function scatter!(op::typeof(mean), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) Ns = scatter!(+, zero(dst), one.(src), idx) dst_ = scatter!(+, zero(dst), src, idx) @@ -88,16 +120,15 @@ function scatter!(op::typeof(mean), dst::AbstractArray, src::AbstractArray, idx: return dst end - """ NNlib.scatter(op, src, idx; [init, dstsize]) -Scatter operation allocating a destination array `dst` and +Scatter operation allocating a destination array `dst` and calling `scatter!(op, dst, src, idx)` on it. * If keyword `init` is provided, it is used to initialize the content of `dst`. Otherwise, the init values is inferred from the reduction operator `op` - for some common operators (e.g. `init = 0` for `op = +`). + for some common operators (e.g. `init = 0` for `op = +`). * If `dstsize` is provided, it will be used to define the size of destination array, otherwise it will be inferred by `src` and `idx`. @@ -127,15 +158,14 @@ julia> NNlib.scatter(*, [10,200,3000], [1,4,2]; init = 10, dstsize = 6) 10 ``` """ -function scatter(op::OP, - src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{Tidx,Nidx}; - init = nothing, dstsize = nothing) where {Tsrc,Tidx,Nsrc,Nidx,OP} - +function scatter( + op::OP, src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}; + init = nothing, dstsize = nothing, +) where {Tsrc,Tidx,Nsrc,Nidx,OP} dims = Nsrc - Nidx - dstsz = isnothing(dstsize) ? (size(src)[1:dims]..., maximum_dims(idx)...) : dstsize + dstsz = isnothing(dstsize) ? (size(src)[1:dims]..., maximum_dims(idx)...) : dstsize dst = similar(src, Tsrc, dstsz) - xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init + xinit = isnothing(init) ? 
scatter_empty(op, Tsrc) : init fill!(dst, xinit) scatter!(op, dst, src, idx) end @@ -147,13 +177,11 @@ scatter_empty(op::typeof(min), T) = typemax(T) scatter_empty(op::typeof(max), T) = typemin(T) scatter_empty(op::typeof(mean), T) = zero(T) - ## Gradients -∇scatter!_src(op, Δ, dst, src, idx) = ∇scatter_src(op, Δ, dst, src, idx) +∇scatter!_src(op, Δ, dst, src, idx) = ∇scatter_src(op, Δ, dst, src, idx) ∇scatter!_dst(op, Δ, dst, y) = Δ - -∇scatter!_dst(op::Union{typeof(max),typeof(min)}, Δ, dst_old, dst) = +∇scatter!_dst(op::Union{typeof(max),typeof(min)}, Δ, dst_old, dst) = (dst_old .== op.(dst_old, dst)) .* Δ modify_src(::typeof(+), X) = X @@ -162,13 +190,13 @@ modify_src(::typeof(*), X, Y) = X modify_src(::typeof(/), X, Y) = .-X ./ Y.^2 ∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = modify_src(op, gather(Δ, idx)) - -∇scatter!_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src, idx) = +∇scatter!_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src, idx) = gather(dst, idx) .* ∇scatter_src(op, Δ, dst, src, idx) -function ∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} +function ∇scatter_src( + op::Union{typeof(*),typeof(/)}, Δ, dst, + src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}, +) where {Tsrc,Tidx,Nsrc,Nidx} dims = Nsrc - Nidx Δsrc = modify_src(op, gather(Δ, idx), src) rev_idx = reverse_indices(idx) @@ -182,12 +210,13 @@ function ∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, Δsrc end -∇scatter_src(::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = (src .== gather(dst, idx)) .* gather(Δ, idx) +∇scatter_src(::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = + (src .== gather(dst, idx)) .* gather(Δ, idx) -function ∇scatter_src(::typeof(mean), Δ, dst, - src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} - +function ∇scatter_src( + ::typeof(mean), Δ, dst, + src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}, +) where {Tsrc,Tidx,Nsrc,Nidx} M = typelength(Tidx) num = gather(Δ, idx) counts = fill!(similar(Δ, Int, size(Δ)[end-M+1:end]), 0) diff --git a/test/gather.jl b/test/gather.jl index 21ce1630d..818e9ad5e 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -1,162 +1,172 @@ using NNlib: gather, gather! 
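+# The tests below are wrapped in `gather_testsuite(Backend)` so that one suite can
+# run on the CPU and on any KernelAbstractions GPU backend; `adapt` moves data
+# to the device before each call and back to the host for comparison.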
-@testset "gather scalar index" begin +function gather_testsuite(Backend) + cpu, backend = CPU(), Backend() T = Float32 - - ## 1d src, 2d index of ints -> 2d output - src = T[3, 4, 5, 6, 7] - index = [1 2 3 4; - 4 2 1 3; - 3 5 5 3] - output = T[3 4 5 6; - 6 4 3 5; - 5 7 7 5] - - y = gather(src, index) - @test y isa Array{T,2} - @test size(y) == size(index) - @test y == output - @test gather!(T.(zero(index)), src, index) == output - @test_throws ArgumentError gather!(zeros(T, 3, 5), src, index) - - index2 = [1 2 3 4; - 4 2 1 3; - 3 6 5 3] - @test_throws BoundsError gather!(T.(zero(index)), src, index2) - - ## 1d src, 3d index of ints -> 3d output - src = T[3, 4, 5, 6, 7] - index = [1 2 3 4; - 4 2 1 3; - 3 5 5 3][:,:,1:1] - output = T[3 4 5 6; - 6 4 3 5; - 5 7 7 5][:,:,1:1] - - y = gather(src, index) - @test y isa Array{T,3} - @test size(y) == size(index) - @test y == output - - - ## 2d src, 2d index of ints -> 3d output - src = T[3 5 7 - 4 6 8] - index = [1 2 3; - 2 2 1; - 3 1 3] - - output = zeros(T, 2, 3, 3) - - output[:,:,1] = [3 5 7 - 4 6 8] - - output[:,:,2] = [5 5 3 - 6 6 4] - - output[:,:,3] = [7 3 7 - 8 4 8] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) - @test y == output -end - -@testset "gather tuple index" begin - T = Float32 - - ## 2d src, 1d index of 2-tuples -> 1d output - src = T[3 5 7 - 4 6 8] - - index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] - - output = T[3, 5, 7, 4, 6, 8] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,1} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) - @test y == output - - ## 3d src, 2d index of 2-tuples -> 3d output - n1, nsrc, nidx = 2, 3, 6 - src = rand(Float32, n1, nsrc, nsrc) - index = [(rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) -end + gradtest_fn = backend == CPU() ? gradtest : gputest -@testset "gather cartesian index" begin - T = Float32 - - ## 2d src, 1d index of 2-tuples -> 1d output - src = T[3 5 7 - 4 6 8] - - index = CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) - - output = T[3, 5, 7, 4, 6, 8] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,1} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) - @test y == output - - ## 3d src, 2d index of 2-tuples -> 3d output - n1, nsrc, nidx = 2, 3, 6 - src = rand(Float32, n1, nsrc, nsrc) - index = [CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i=1:nidx, j=1:nidx] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) 
-end - -@testset "gather gradient for scalar index" begin - T = Float64 - src = T[3, 4, 5, 6, 7] - index = [1 2 3 4; + @testset "gather scalar index" begin + ## 1d src, 2d index of ints -> 2d output + src = adapt(backend, T[3, 4, 5, 6, 7]) + index = adapt(backend, [ + 1 2 3 4; 4 2 1 3; - 3 5 5 3] - dst = T[3 4 5 6; + 3 5 5 3]) + output = T[ + 3 4 5 6; 6 4 3 5; 5 7 7 5] - gradtest(xs -> gather!(dst, xs, index), src) - gradtest(xs -> gather(xs, index), src) -end + y = adapt(cpu, gather(src, index)) + @test y isa Array{T,2} + @test size(y) == size(index) + @test y == output + + dst = adapt(backend, T.(zero(index))) + @test adapt(cpu, gather!(dst, src, index)) == output + dst = adapt(backend, zeros(T, 3, 5)) + @test_throws ArgumentError gather!(dst, src, index) + + if Backend == CPU + index2 = [1 2 3 4; + 4 2 1 3; + 3 6 5 3] + @test_throws BoundsError gather!(T.(zero(index)), src, index2) + end + + ## 1d src, 3d index of ints -> 3d output + src = adapt(backend, T[3, 4, 5, 6, 7]) + index = adapt(backend, [ + 1 2 3 4; + 4 2 1 3; + 3 5 5 3][:,:,1:1]) + output = T[ + 3 4 5 6; + 6 4 3 5; + 5 7 7 5][:,:,1:1] + + y = adapt(cpu, gather(src, index)) + @test y isa Array{T,3} + @test size(y) == size(index) + @test y == output + + ## 2d src, 2d index of ints -> 3d output + src = adapt(backend, T[ + 3 5 7 + 4 6 8]) + index = adapt(backend, [ + 1 2 3; + 2 2 1; + 3 1 3]) -@testset "gather gradient for tuple index" begin - T = Float64 - src = T[3 5 7 + output = zeros(T, 2, 3, 3) + output[:,:,1] = [ + 3 5 7 4 6 8] - index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] - dst = T[3, 5, 7, 4, 6, 8] - - gradtest(xs -> gather!(dst, xs, index), src) - gradtest(xs -> gather(xs, index), src) + output[:,:,2] = [ + 5 5 3 + 6 6 4] + output[:,:,3] = [ + 7 3 7 + 8 4 8] + + y = adapt(cpu, gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y == output + end + + @testset "gather tuple index" begin + ## 2d src, 1d index of 2-tuples -> 1d output + src = adapt(backend, T[ + 3 5 7 + 4 6 8]) + index = adapt(backend, [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) + output = T[3, 5, 7, 4, 6, 8] + + y = adapt(cpu, gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y == output + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = adapt(backend, rand(T, n1, nsrc, nsrc)) + index = adapt(backend, [ + (rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx]) + + y = adapt(cpu, gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + end + + @testset "gather cartesian index" begin + ## 2d src, 1d index of 2-tuples -> 1d output + src = adapt(backend, T[ + 3 5 7 + 4 6 8]) + index = adapt(backend, CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)])) + output = T[3, 5, 7, 4, 6, 8] + + y = adapt(cpu, gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) 
+ @test y == output + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = adapt(backend, rand(Float32, n1, nsrc, nsrc)) + index = adapt(backend, [ + CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i=1:nidx, j=1:nidx]) + + y = adapt(cpu, gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + end end -@testset "gather(src, IJK...)" begin - x = reshape([1:15;], 3, 5) - - y = gather(x, [1,2], [2,4]) - @test y == [4, 11] - - @test gather(x, [1, 2]) == [1 4 - 2 5 - 3 6] -end +# @testset "gather gradient for scalar index" begin +# T = Float64 +# src = T[3, 4, 5, 6, 7] +# index = [1 2 3 4; +# 4 2 1 3; +# 3 5 5 3] +# dst = T[3 4 5 6; +# 6 4 3 5; +# 5 7 7 5] + +# gradtest(xs -> gather!(dst, xs, index), src) +# gradtest(xs -> gather(xs, index), src) +# end + +# @testset "gather gradient for tuple index" begin +# T = Float64 +# src = T[3 5 7 +# 4 6 8] +# index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] +# dst = T[3, 5, 7, 4, 6, 8] + +# gradtest(xs -> gather!(dst, xs, index), src) +# gradtest(xs -> gather(xs, index), src) +# end + +# @testset "gather(src, IJK...)" begin +# x = reshape([1:15;], 3, 5) + +# y = gather(x, [1,2], [2,4]) +# @test y == [4, 11] + +# @test gather(x, [1, 2]) == [1 4 +# 2 5 +# 3 6] +# end diff --git a/test/runtests.jl b/test/runtests.jl index 8b9061b77..a3db7c8f2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,11 +28,19 @@ end cpu(x) = adapt(CPU(), x) +include("gather.jl") +include("scatter.jl") include("upsample.jl") function nnlib_testsuite(Backend; skip_tests = Set{String}()) - @conditional_testset "Upsample" skip_tests begin - upsample_testsuite(Backend) + # @conditional_testset "Upsample" skip_tests begin + # upsample_testsuite(Backend) + # end + # @conditional_testset "Gather" skip_tests begin + # gather_testsuite(Backend) + # end + @conditional_testset "Scatter" skip_tests begin + scatter_testsuite(Backend) end end @@ -80,115 +88,107 @@ end end end - @testset "Tests" begin - if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" - using CUDA - if CUDA.functional() - import Pkg - using NNlibCUDA - @testset "CUDA" begin - Pkg.test("NNlibCUDA") - end - else - @info "Insufficient version or CUDA not found; Skipping CUDA tests" - end - else - @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them" - end - - if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" - import Pkg - test_info = Pkg.project() - # Add MIOpen_jll to AMDGPU. - Pkg.develop("AMDGPU") - Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) - Pkg.add("MIOpen_jll") - Pkg.update() - # Update test project. - Pkg.activate(test_info.path) - Pkg.update() - - using AMDGPU - AMDGPU.versioninfo() - if AMDGPU.functional() && AMDGPU.functional(:MIOpen) - @show AMDGPU.MIOpen.version() - @testset "AMDGPU" begin - include("ext_amdgpu/runtests.jl") - end - else - @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." - end - else - @info "Skipping AMDGPU tests, set NNLIB_TEST_CUDA=true to run them." 
- end - - @testset "Doctests" begin - doctest(NNlib, manual=false) - end - - @testset "Activation Functions" begin - include("activations.jl") - end - - @testset "Attention" begin - include("attention.jl") - end - - @testset "Batched Multiplication" begin - include("batchedmul.jl") - end - - @testset "Convolution" begin - include("conv.jl") - include("conv_bias_act.jl") - end - - @testset "CTC Loss" begin - include("ctc.jl") - end - - @testset "Dropout" begin - include("dropout.jl") - end - - @testset "Fold/Unfold" begin - include("fold.jl") - end - - @testset "Inference" begin - include("inference.jl") - end - - @testset "Pooling" begin - include("pooling.jl") - end - - @testset "Padding" begin - include("padding.jl") - end - - @testset "Softmax" begin - include("softmax.jl") - end - - @testset "Gather" begin - include("gather.jl") - end - - @testset "Scatter" begin - include("scatter.jl") - end - - @testset "Utilities" begin - include("utils.jl") - end - - @testset "Grid Sampling" begin - include("sampling.jl") - end - - @testset "Functions" begin - include("functions.jl") - end - end + # @testset verbose=true "Tests" begin + # if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" + # using CUDA + # if CUDA.functional() + # import Pkg + # using NNlibCUDA + # @testset "CUDA" begin + # Pkg.test("NNlibCUDA") + # end + # else + # @info "Insufficient version or CUDA not found; Skipping CUDA tests" + # end + # else + # @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them" + # end + + # if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" + # import Pkg + # test_info = Pkg.project() + # # Add MIOpen_jll to AMDGPU. + # Pkg.develop("AMDGPU") + # Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) + # Pkg.add("MIOpen_jll") + # Pkg.update() + # # Update test project. + # Pkg.activate(test_info.path) + # Pkg.update() + + # using AMDGPU + # AMDGPU.versioninfo() + # if AMDGPU.functional() && AMDGPU.functional(:MIOpen) + # @show AMDGPU.MIOpen.version() + # @testset "AMDGPU" begin + # include("ext_amdgpu/runtests.jl") + # end + # else + # @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." + # end + # else + # @info "Skipping AMDGPU tests, set NNLIB_TEST_CUDA=true to run them." 
+ # end + + # @testset "Doctests" begin + # doctest(NNlib, manual=false) + # end + + # @testset "Activation Functions" begin + # include("activations.jl") + # end + + # @testset "Attention" begin + # include("attention.jl") + # end + + # @testset "Batched Multiplication" begin + # include("batchedmul.jl") + # end + + # @testset "Convolution" begin + # include("conv.jl") + # include("conv_bias_act.jl") + # end + + # @testset "CTC Loss" begin + # include("ctc.jl") + # end + + # @testset "Dropout" begin + # include("dropout.jl") + # end + + # @testset "Fold/Unfold" begin + # include("fold.jl") + # end + + # @testset "Inference" begin + # include("inference.jl") + # end + + # @testset "Pooling" begin + # include("pooling.jl") + # end + + # @testset "Padding" begin + # include("padding.jl") + # end + + # @testset "Softmax" begin + # include("softmax.jl") + # end + + # @testset "Utilities" begin + # include("utils.jl") + # end + + # @testset "Grid Sampling" begin + # include("sampling.jl") + # end + + # @testset "Functions" begin + # include("functions.jl") + # end + # end end diff --git a/test/scatter.jl b/test/scatter.jl index 0383e8e5d..66a104437 100644 --- a/test/scatter.jl +++ b/test/scatter.jl @@ -15,13 +15,13 @@ idxs = Dict( :int => [1 2 3 4; 4 2 1 3; 3 5 5 3], - :tup => [(1,) (2,) (3,) (4,); - (4,) (2,) (1,) (3,); - (3,) (5,) (5,) (3,)], - :car => CartesianIndex.( - [(1,) (2,) (3,) (4,); - (4,) (2,) (1,) (3,); - (3,) (5,) (5,) (3,)]), + # :tup => [(1,) (2,) (3,) (4,); + # (4,) (2,) (1,) (3,); + # (3,) (5,) (5,) (3,)], + # :car => CartesianIndex.( + # [(1,) (2,) (3,) (4,); + # (4,) (2,) (1,) (3,); + # (3,) (5,) (5,) (3,)]), ) res = Dict( (+, 0, true) => [5, 6, 9, 8, 9], @@ -44,168 +44,205 @@ res = Dict( 6 4 8 8 6], (min, 0, true) => [1, 1, 1, 1, 1], (min, 1, true) => [1 1 1 1 1; - 1 1 1 1 1], + 1 1 1 1 1], (min, 0, false) => [1, 2, 1, 1, 2], (min, 1, false) => [1 2 1 1 2; 2 4 2 2 4], (*, 0, true) => [3, 4, 5, 6, 7], (*, 1, true) => [3 3 4 4 5; - 5 5 6 6 7], + 5 5 6 6 7], (*, 0, false) => [3, 4, 48, 4, 6], (*, 1, false) => [3 4 48 4 6; 12 16 768 16 24], (/, 0, true) => [0.75, 1., 0.3125, 1.5, 1.75], (/, 1, true) => [0.75 0.75 0.25 1. 1.25; - 1.25 1.25 0.375 1.5 1.75], + 1.25 1.25 0.375 1.5 1.75], (/, 0, false) => [1//3, 1//4, 1//48, 1//4, 1//6], (/, 1, false) => [1//3 1//4 1//48 1//4 1//6; 1//12 1//16 1//768 1//16 1//24], (mean, 0, true) => [4., 5., 6., 7., 8.], (mean, 1, true) => [4. 4. 5. 5. 6.; - 6. 6. 7. 7. 8.], + 6. 6. 7. 7. 8.], (mean, 0, false) => [2, 2, 3, 2.5, 2.5], (mean, 1, false) => [2. 2. 3. 2.5 2.5; 4. 4. 6. 5. 5.], ) -types = [UInt8, UInt32, Int64, - Float16, Float32, Float64, BigFloat, Rational] +function scatter_testsuite(Backend) + cpu, backend = CPU(), Backend() + gradtest_fn = backend == CPU() ? 
gradtest : gputest + # types = [UInt8, UInt32, Int64, Float16, Float32, Float64, BigFloat, Rational] + types = [Int64] -@testset "scatter" begin - for T = types - @testset "$T" begin - PT = promote_type(T, Int) + for T in types + PT = promote_type(T, Int) + @testset "T" begin @testset "+" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(+, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)]) - @test scatter!(+, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(+, dims, mutated)]) - @test scatter!(+, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(+, dims, mutated)]) - - mutated = false - @test scatter(+, T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)]) - end - end - - @testset "-" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(-, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)]) - @test scatter!(-, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(-, dims, mutated)]) - @test scatter!(-, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(-, dims, mutated)]) - - mutated = false - if !(T in [UInt8, UInt16, UInt32, UInt64, UInt128]) - @test scatter(-, T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)]) - end - end - end - - @testset "max" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(max, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)]) - @test scatter!(max, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(max, dims, mutated)]) - @test scatter!(max, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(max, dims, mutated)]) - - mutated = false - if !(T in [BigInt]) - @test scatter(max, T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)]) - end - end - end - - @testset "min" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(min, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)]) - @test scatter!(min, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(min, dims, mutated)]) - @test scatter!(min, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(min, dims, mutated)]) - - mutated = false - if !(T in [BigInt]) - @test scatter(min, T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)]) - end - end - end + # for idx = values(idxs), dims = [0, 1] + for idx = values(idxs), dims = [0] + idx = adapt(backend, idx) - @testset "*" begin - for idx = values(idxs), dims = [0, 1] mutated = true - @test scatter!(*, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)]) - @test scatter!(*, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(*, dims, mutated)]) - @test scatter!(*, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(*, dims, mutated)]) - - mutated = false - if !(T in [UInt8, Int8]) - @test scatter(*, T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)]) - end + src = adapt(backend, srcs[(dims, mutated)]) + dst = adapt(backend, dsts[dims]) + target_y = res[(+, dims, mutated)] + + @show src + @show idx + @show dst + @show target_y + y = scatter!(+, T.(dst), T.(src), idx) + @test adapt(cpu, y) == T.(target_y) + # y = scatter!(+, T.(dst), src, idx) + # @test adapt(cpu, y) == PT.(target_y) + # y = scatter!(+, copy(dst), T.(src), idx) + # @test adapt(cpu, y) == PT.(target_y) + + # mutated = false + # src = adapt(backend, srcs[(dims, mutated)]) + 
# y = scatter(+, T.(src), idx) + # @test adapt(cpu, y) == T.(res[(+, dims, mutated)]) end end end end - - for T = [Float16, Float32, BigFloat, Rational] - @testset "$T" begin - PT = promote_type(T, Float64) - @testset "/" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == T.(res[(/, dims, mutated)]) - @test scatter!(/, T.(dsts[dims]), srcs[(dims, mutated)].*2, idx) == PT.(res[(/, dims, mutated)]) - @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == PT.(res[(/, dims, mutated)]) - - mutated = false - @test scatter(/, T.(srcs[(dims, mutated)]), idx) == T.(res[(/, dims, mutated)]) - end - end - - @testset "mean" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(mean, T.(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)]) - @test scatter!(mean, T.(dsts[dims]), srcs[(dims, mutated)], idx) == PT.(res[(mean, dims, mutated)]) - @test scatter!(mean, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(mean, dims, mutated)]) - - mutated = false - @test scatter(mean, T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)]) - end - end - end - end - - @test_throws AssertionError scatter!(+, dsts[0], srcs[(1, true)], idxs[:int]) - idx = [1 2 3 4; 4 2 1 3; 6 7 8 9] - @test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx) - - @testset "dstsize" begin - idx = [2, 2, 3, 4, 4] - src = ones(3, 5) - y = scatter(+, src, idx, dstsize = (3, 6)) - @test size(y) == (3, 6) - gradtest(x -> scatter(+, x, idx, dstsize = (3,6)), src) - end end -@testset "∇scatter" begin - T = Float64 - fdm(op) = op == min ? :backward : :forward - # fdm(op) = :forward - - @testset "∂dst" begin - for op in (+, -, *, /, mean, max, min) - gradtest(xs -> scatter!(op, copy(xs), srcs[(0, true)], idxs[:int]), T.(dsts[0]), fdm=fdm(op)) - gradtest(xs -> scatter!(op, copy(xs), srcs[(1, true)], idxs[:int]), T.(dsts[1]), fdm=fdm(op)) - end - end - - @testset "∂src" begin - for op in (+, -, *, /, mean, max, min) - gradtest(xs -> scatter!(op, T.(dsts[0]), xs, idxs[:int]), T.(srcs[(0, true)]), fdm=fdm(op)) - gradtest(xs -> scatter!(op, T.(dsts[1]), xs, idxs[:int]), T.(srcs[(1, true)]), fdm=fdm(op)) - - gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(0, false)]), fdm=fdm(op)) - gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(1, false)]), fdm=fdm(op)) - end - end -end +# @testset "scatter" begin +# for T = types +# @testset "$T" begin +# # PT = promote_type(T, Int) +# # @testset "+" begin +# # for idx = values(idxs), dims = [0, 1] +# # mutated = true +# # @test scatter!(+, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)]) +# # @test scatter!(+, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(+, dims, mutated)]) +# # @test scatter!(+, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(+, dims, mutated)]) + +# # mutated = false +# # @test scatter(+, T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)]) +# # end +# # end + +# @testset "-" begin +# for idx = values(idxs), dims = [0, 1] +# mutated = true +# @test scatter!(-, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)]) +# @test scatter!(-, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(-, dims, mutated)]) +# @test scatter!(-, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(-, dims, mutated)]) + +# mutated = false +# if !(T in [UInt8, UInt16, UInt32, UInt64, UInt128]) +# 
@test scatter(-, T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)]) +# end +# end +# end + +# @testset "max" begin +# for idx = values(idxs), dims = [0, 1] +# mutated = true +# @test scatter!(max, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)]) +# @test scatter!(max, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(max, dims, mutated)]) +# @test scatter!(max, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(max, dims, mutated)]) + +# mutated = false +# if !(T in [BigInt]) +# @test scatter(max, T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)]) +# end +# end +# end + +# @testset "min" begin +# for idx = values(idxs), dims = [0, 1] +# mutated = true +# @test scatter!(min, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)]) +# @test scatter!(min, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(min, dims, mutated)]) +# @test scatter!(min, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(min, dims, mutated)]) + +# mutated = false +# if !(T in [BigInt]) +# @test scatter(min, T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)]) +# end +# end +# end + +# @testset "*" begin +# for idx = values(idxs), dims = [0, 1] +# mutated = true +# @test scatter!(*, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)]) +# @test scatter!(*, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(*, dims, mutated)]) +# @test scatter!(*, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(*, dims, mutated)]) + +# mutated = false +# if !(T in [UInt8, Int8]) +# @test scatter(*, T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)]) +# end +# end +# end +# end +# end + +# for T = [Float16, Float32, BigFloat, Rational] +# @testset "$T" begin +# PT = promote_type(T, Float64) +# @testset "/" begin +# for idx = values(idxs), dims = [0, 1] +# mutated = true +# @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == T.(res[(/, dims, mutated)]) +# @test scatter!(/, T.(dsts[dims]), srcs[(dims, mutated)].*2, idx) == PT.(res[(/, dims, mutated)]) +# @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == PT.(res[(/, dims, mutated)]) + +# mutated = false +# @test scatter(/, T.(srcs[(dims, mutated)]), idx) == T.(res[(/, dims, mutated)]) +# end +# end + +# @testset "mean" begin +# for idx = values(idxs), dims = [0, 1] +# mutated = true +# @test scatter!(mean, T.(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)]) +# @test scatter!(mean, T.(dsts[dims]), srcs[(dims, mutated)], idx) == PT.(res[(mean, dims, mutated)]) +# @test scatter!(mean, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(mean, dims, mutated)]) + +# mutated = false +# @test scatter(mean, T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)]) +# end +# end +# end +# end + +# @test_throws AssertionError scatter!(+, dsts[0], srcs[(1, true)], idxs[:int]) +# idx = [1 2 3 4; 4 2 1 3; 6 7 8 9] +# @test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx) + +# @testset "dstsize" begin +# idx = [2, 2, 3, 4, 4] +# src = ones(3, 5) +# y = scatter(+, src, idx, dstsize = (3, 6)) +# @test size(y) == (3, 6) +# gradtest(x -> scatter(+, x, idx, dstsize = (3,6)), src) +# end +# end + +# @testset "∇scatter" begin +# T = Float64 +# fdm(op) = op == min ? 
:backward : :forward +# # fdm(op) = :forward + +# @testset "∂dst" begin +# for op in (+, -, *, /, mean, max, min) +# gradtest(xs -> scatter!(op, copy(xs), srcs[(0, true)], idxs[:int]), T.(dsts[0]), fdm=fdm(op)) +# gradtest(xs -> scatter!(op, copy(xs), srcs[(1, true)], idxs[:int]), T.(dsts[1]), fdm=fdm(op)) +# end +# end + +# @testset "∂src" begin +# for op in (+, -, *, /, mean, max, min) +# gradtest(xs -> scatter!(op, T.(dsts[0]), xs, idxs[:int]), T.(srcs[(0, true)]), fdm=fdm(op)) +# gradtest(xs -> scatter!(op, T.(dsts[1]), xs, idxs[:int]), T.(srcs[(1, true)]), fdm=fdm(op)) + +# gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(0, false)]), fdm=fdm(op)) +# gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(1, false)]), fdm=fdm(op)) +# end +# end +# end From 1a4c557dd466c6cb06abb30bb347566d4d63f746 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sun, 9 Apr 2023 12:05:21 +0300 Subject: [PATCH 02/10] Finish scatter --- Project.toml | 1 + src/NNlib.jl | 1 + src/scatter.jl | 124 ++++++++++++++++++------- test/scatter.jl | 237 ++++++++++++++++++------------------------------ 4 files changed, 181 insertions(+), 182 deletions(-) diff --git a/Project.toml b/Project.toml index 5bfef021f..8bc81e773 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.8.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" diff --git a/src/NNlib.jl b/src/NNlib.jl index 7dfbd8805..183d83cbc 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -1,5 +1,6 @@ module NNlib +import Atomix import ChainRulesCore: rrule using Base.Broadcast: broadcasted diff --git a/src/scatter.jl b/src/scatter.jl index f0d17039e..b4e5e54dd 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -70,16 +70,25 @@ julia> NNlib.scatter!(*, fill(0.5, 2, 4), [1 10; 100 1000], [3,2]) 0.5 500.0 50.0 0.5 ``` """ -# function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) where OP -# dims = scatter_dims(dst, src, idx) -# colons = Base.ntuple(_->Colon(), dims) -# for k in CartesianIndices(idx) -# dst_v = _view(dst, colons, idx[k]) -# src_v = _view(src, colons, k) -# dst_v .= (op).(dst_v, src_v) -# end -# dst -# end +function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) where OP + dims = scatter_dims(dst, src, idx) + colons = Base.ntuple(_->Colon(), dims) + for k in CartesianIndices(idx) + dst_v = _view(dst, colons, idx[k]) + src_v = _view(src, colons, k) + dst_v .= (op).(dst_v, src_v) + end + dst +end + +for AT in (AbstractArray, AbstractGPUArray) + @eval function scatter!(op::typeof(mean), dst::$AT, src::$AT, idx::$AT) + Ns = scatter!(+, zero(dst), one.(src), idx) + dst_ = scatter!(+, zero(dst), src, idx) + dst .+= safe_div.(dst_, Ns) + return dst + end +end function scatter!(op::OP, dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) where OP n_dims = scatter_dims(dst, src, idx) @@ -100,24 +109,22 @@ end @kernel function _scatter!(op::OP, dst, src, idxs) where OP i = @index(Global) idx = Tuple(idxs[i]) - @atomic dst[idx...] = op(dst[idx...], src[i]) + Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) + # FIXME `@atomic` macro silently fails to perform atomic op below + # @atomic dst[idx...] 
= op(dst[idx...], src[i]) end -# @kernel function _scatter!( -# op::OP, dst, src, idxs, dim_ids::CartesianIndices, max_dims_idx::Int, -# ) where OP -# i = @index(Global) -# j, k = divrem(i - 1, max_dims_idx) -# dim_i = Tuple(dim_ids[k + 1]) -# idx = idxs[j + 1] -# @atomic dst[dim_i..., idx...] = op(dst[dim_i..., idx...], src[i]) -# end - -function scatter!(op::typeof(mean), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) - Ns = scatter!(+, zero(dst), one.(src), idx) - dst_ = scatter!(+, zero(dst), src, idx) - dst .+= safe_div.(dst_, Ns) - return dst +@kernel function _scatter!( + op::OP, dst, src, idxs, dim_ids::CartesianIndices, max_dims_idx::Int, +) where OP + i = @index(Global) + j, k = divrem(i - 1, max_dims_idx) + idx = (Tuple(dim_ids[k + 1])..., Tuple(idxs[j + 1])...) + Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) + # FIXME + # dim_i = Tuple(dim_ids[k + 1]) + # idx = idxs[j + 1] + # @atomic dst[dim_i..., idx...] = op(dst[dim_i..., idx...], src[i]) end """ @@ -180,6 +187,8 @@ scatter_empty(op::typeof(mean), T) = zero(T) ## Gradients ∇scatter!_src(op, Δ, dst, src, idx) = ∇scatter_src(op, Δ, dst, src, idx) +∇scatter!_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src, idx) = + gather(dst, idx) .* ∇scatter_src(op, Δ, dst, src, idx) ∇scatter!_dst(op, Δ, dst, y) = Δ ∇scatter!_dst(op::Union{typeof(max),typeof(min)}, Δ, dst_old, dst) = (dst_old .== op.(dst_old, dst)) .* Δ @@ -189,9 +198,10 @@ modify_src(::typeof(-), X) = -X modify_src(::typeof(*), X, Y) = X modify_src(::typeof(/), X, Y) = .-X ./ Y.^2 -∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = modify_src(op, gather(Δ, idx)) -∇scatter!_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src, idx) = - gather(dst, idx) .* ∇scatter_src(op, Δ, dst, src, idx) +∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = + modify_src(op, gather(Δ, idx)) +∇scatter_src(::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = + (src .== gather(dst, idx)) .* gather(Δ, idx) function ∇scatter_src( op::Union{typeof(*),typeof(/)}, Δ, dst, @@ -210,8 +220,57 @@ function ∇scatter_src( Δsrc end -∇scatter_src(::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = - (src .== gather(dst, idx)) .* gather(Δ, idx) +function ∇scatter_src( + op::Union{typeof(*), typeof(/)}, Δ, dst, + src::AbstractGPUArray{Tsrc, Nsrc}, idx::AbstractGPUArray{Tidx, Nidx}, +) where {Tsrc, Nsrc, Tidx, Nidx} + n_dims = Nsrc - Nidx + Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) + rev_idx = NNlib.reverse_indices(idx) + + args = if n_dims == 0 + ndrange = length(idx) + () + else + dims = size(dst)[1:n_dims] + max_dims_idx = prod(dims) + ndrange = max_dims_idx * length(idx) + (CartesianIndices(dims), max_dims_idx) + end + _∇scatter_src(KernelAbstractions.get_backend(src))( + op, Δsrc, src, idx, rev_idx, args...; ndrange) + # TODO KernelAbstractions.unsafe_free!(rev_idx) + return Δsrc +end + +@kernel function _∇scatter_src(op, Δsrc, src::AbstractArray{T}, idx, rev_idx) where T + i = @index(Global) + cart_j = CartesianIndices(idx)[i] + inds = rev_idx[Tuple(idx[cart_j])...] + x = one(T) + for k in inds + x *= src[k] + end + x /= src[cart_j] + Δsrc[cart_j] = op(Δsrc[cart_j], x) +end + +@kernel function _∇scatter_src( + op, Δsrc, src::AbstractArray{T}, idx, rev_idx, + dim_ids::CartesianIndices, max_dims_idx::Int, +) where T + i = @index(Global) + j, k = fldmod1(i, max_dims_idx) + cart_j = CartesianIndices(idx)[j] + cart_k = dim_ids[k] + inds = rev_idx[Tuple(cart_j)...] + x = one(T) + for s in inds + x *= src[Tuple(cart_k)..., Tuple(s)...] 
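+        # `x` accumulates the product of every `src` entry scattered into the same
+        # destination slot; the division by `src[i]` below removes this element's
+        # own factor, leaving the product of the other contributions.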
+ end + x /= src[i] + Δsrc[i] = op(Δsrc[i], x) +end function ∇scatter_src( ::typeof(mean), Δ, dst, @@ -229,8 +288,7 @@ function ∇scatter_src( return safe_div.(num, den) end -∇scatter_src(op, Δ, dst, src, idx) = - ∇scatter_src(op, unthunk(Δ), dst, src, idx) +∇scatter_src(op, Δ, dst, src, idx) = ∇scatter_src(op, unthunk(Δ), dst, src, idx) function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) dst_old = copy(dst) diff --git a/test/scatter.jl b/test/scatter.jl index 66a104437..05cdc2e42 100644 --- a/test/scatter.jl +++ b/test/scatter.jl @@ -15,13 +15,13 @@ idxs = Dict( :int => [1 2 3 4; 4 2 1 3; 3 5 5 3], - # :tup => [(1,) (2,) (3,) (4,); - # (4,) (2,) (1,) (3,); - # (3,) (5,) (5,) (3,)], - # :car => CartesianIndex.( - # [(1,) (2,) (3,) (4,); - # (4,) (2,) (1,) (3,); - # (3,) (5,) (5,) (3,)]), + :tup => [(1,) (2,) (3,) (4,); + (4,) (2,) (1,) (3,); + (3,) (5,) (5,) (3,)], + :car => CartesianIndex.( + [(1,) (2,) (3,) (4,); + (4,) (2,) (1,) (3,); + (3,) (5,) (5,) (3,)]), ) res = Dict( (+, 0, true) => [5, 6, 9, 8, 9], @@ -68,152 +68,97 @@ res = Dict( 4. 4. 6. 5. 5.], ) -function scatter_testsuite(Backend) - cpu, backend = CPU(), Backend() - gradtest_fn = backend == CPU() ? gradtest : gputest - # types = [UInt8, UInt32, Int64, Float16, Float32, Float64, BigFloat, Rational] - types = [Int64] - +function test_scatter(backend, types, ops; pt, ops_skip_types) + cpu = CPU() for T in types - PT = promote_type(T, Int) - @testset "T" begin - @testset "+" begin - # for idx = values(idxs), dims = [0, 1] - for idx = values(idxs), dims = [0] - idx = adapt(backend, idx) - - mutated = true - src = adapt(backend, srcs[(dims, mutated)]) - dst = adapt(backend, dsts[dims]) - target_y = res[(+, dims, mutated)] - - @show src - @show idx - @show dst - @show target_y - y = scatter!(+, T.(dst), T.(src), idx) - @test adapt(cpu, y) == T.(target_y) - # y = scatter!(+, T.(dst), src, idx) - # @test adapt(cpu, y) == PT.(target_y) - # y = scatter!(+, copy(dst), T.(src), idx) - # @test adapt(cpu, y) == PT.(target_y) - - # mutated = false - # src = adapt(backend, srcs[(dims, mutated)]) - # y = scatter(+, T.(src), idx) - # @test adapt(cpu, y) == T.(res[(+, dims, mutated)]) + PT = promote_type(T, pt) + @testset failfast=true "$T" begin + for op in ops + skip_types = get(ops_skip_types, op, []) + @testset "$op" begin + for idx = values(idxs), dims = [0, 1] + idx = adapt(backend, idx) + dst = adapt(backend, dsts[dims]) + + mutated = true + target_y = res[(op, dims, mutated)] + src = adapt(backend, srcs[(dims, mutated)]) + if op == / + src = src .* T(2) + end + + @test adapt(cpu, scatter!(op, T.(dst), T.(src), idx)) == T.(target_y) + @test adapt(cpu, scatter!(op, T.(dst), src, idx)) == PT.(target_y) + if op == / + @test adapt(cpu, scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y) + else + @test adapt(cpu, scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y) + end + + if T ∉ skip_types + mutated = false + src = adapt(backend, srcs[(dims, mutated)]) + @test adapt(cpu, scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)]) + end + end end end end end end -# @testset "scatter" begin -# for T = types -# @testset "$T" begin -# # PT = promote_type(T, Int) -# # @testset "+" begin -# # for idx = values(idxs), dims = [0, 1] -# # mutated = true -# # @test scatter!(+, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)]) -# # @test scatter!(+, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(+, dims, mutated)]) -# # @test scatter!(+, 
copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(+, dims, mutated)]) - -# # mutated = false -# # @test scatter(+, T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)]) -# # end -# # end - -# @testset "-" begin -# for idx = values(idxs), dims = [0, 1] -# mutated = true -# @test scatter!(-, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)]) -# @test scatter!(-, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(-, dims, mutated)]) -# @test scatter!(-, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(-, dims, mutated)]) - -# mutated = false -# if !(T in [UInt8, UInt16, UInt32, UInt64, UInt128]) -# @test scatter(-, T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)]) -# end -# end -# end - -# @testset "max" begin -# for idx = values(idxs), dims = [0, 1] -# mutated = true -# @test scatter!(max, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)]) -# @test scatter!(max, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(max, dims, mutated)]) -# @test scatter!(max, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(max, dims, mutated)]) - -# mutated = false -# if !(T in [BigInt]) -# @test scatter(max, T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)]) -# end -# end -# end - -# @testset "min" begin -# for idx = values(idxs), dims = [0, 1] -# mutated = true -# @test scatter!(min, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)]) -# @test scatter!(min, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(min, dims, mutated)]) -# @test scatter!(min, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(min, dims, mutated)]) - -# mutated = false -# if !(T in [BigInt]) -# @test scatter(min, T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)]) -# end -# end -# end - -# @testset "*" begin -# for idx = values(idxs), dims = [0, 1] -# mutated = true -# @test scatter!(*, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)]) -# @test scatter!(*, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(*, dims, mutated)]) -# @test scatter!(*, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(*, dims, mutated)]) - -# mutated = false -# if !(T in [UInt8, Int8]) -# @test scatter(*, T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)]) -# end -# end -# end -# end -# end - -# for T = [Float16, Float32, BigFloat, Rational] -# @testset "$T" begin -# PT = promote_type(T, Float64) -# @testset "/" begin -# for idx = values(idxs), dims = [0, 1] -# mutated = true -# @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == T.(res[(/, dims, mutated)]) -# @test scatter!(/, T.(dsts[dims]), srcs[(dims, mutated)].*2, idx) == PT.(res[(/, dims, mutated)]) -# @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == PT.(res[(/, dims, mutated)]) - -# mutated = false -# @test scatter(/, T.(srcs[(dims, mutated)]), idx) == T.(res[(/, dims, mutated)]) -# end -# end +function scatter_testsuite(Backend) + backend = Backend() + gradtest_fn = backend == CPU() ? 
gradtest : gputest -# @testset "mean" begin -# for idx = values(idxs), dims = [0, 1] -# mutated = true -# @test scatter!(mean, T.(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)]) -# @test scatter!(mean, T.(dsts[dims]), srcs[(dims, mutated)], idx) == PT.(res[(mean, dims, mutated)]) -# @test scatter!(mean, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(mean, dims, mutated)]) + ops_skip_types = Dict( + (+) => [], + (-) => [UInt8, UInt16, UInt32, UInt64, UInt128], + (*) => [UInt8, Int8], + max => [BigInt], + min => [BigInt]) + types = if backend == CPU() + [UInt8, UInt32, UInt64, Int32, Int64, Float16, Float32, Float64, BigFloat, Rational] + elseif Symbol(typeof(backend)) == :CUDABackend + [Int32, Int64, Float32, Float64] + else + # Need LLVM 15+ for atomic fmin/fmax: + # https://reviews.llvm.org/D127041 + # But min/max can be done by reinterpreting an array to `UInt`. + [Int32, Int64, UInt32, UInt64] + end + ops = backend == CPU() ? + (+, -, max, min, *) : + (+, -, max, min) + test_scatter(backend, types, ops; pt=Int, ops_skip_types) + + types = backend == CPU() ? + [Float16, Float32, BigFloat, Rational] : + [Float32, Float64] + ops = if backend == CPU() + (/, mean) + elseif Symbol(typeof(backend)) == :CUDABackend + (*, /, mean) + else + # LLVM does not support atomic fmul/fdiv: + # https://llvm.org/docs/LangRef.html#atomicrmw-instruction + (mean,) + end + test_scatter(backend, types, ops; pt=Float64, ops_skip_types=Dict()) -# mutated = false -# @test scatter(mean, T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)]) -# end -# end -# end -# end + if backend == CPU() + @testset "scatter exceptions" begin + @test_throws AssertionError scatter!(+, dsts[0], srcs[(1, true)], idxs[:int]) + idx = [1 2 3 4; 4 2 1 3; 6 7 8 9] + @test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx) + end + end +end -# @test_throws AssertionError scatter!(+, dsts[0], srcs[(1, true)], idxs[:int]) -# idx = [1 2 3 4; 4 2 1 3; 6 7 8 9] -# @test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx) +# @testset "∇scatter" begin +# T = Float64 +# fdm(op) = op == min ? :backward : :forward +# # fdm(op) = :forward # @testset "dstsize" begin # idx = [2, 2, 3, 4, 4] @@ -222,12 +167,6 @@ end # @test size(y) == (3, 6) # gradtest(x -> scatter(+, x, idx, dstsize = (3,6)), src) # end -# end - -# @testset "∇scatter" begin -# T = Float64 -# fdm(op) = op == min ? :backward : :forward -# # fdm(op) = :forward # @testset "∂dst" begin # for op in (+, -, *, /, mean, max, min) From ea4f155e694a62c3af0452d9a51ee16472c287c0 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sun, 9 Apr 2023 21:14:48 +0300 Subject: [PATCH 03/10] Update testsuite for gather/scatter --- test/gather.jl | 79 +++++++++-------- test/runtests.jl | 218 +++++++++++++++++++++++------------------------ test/scatter.jl | 87 ++++++++++++------- 3 files changed, 209 insertions(+), 175 deletions(-) diff --git a/test/gather.jl b/test/gather.jl index 818e9ad5e..1cf88843f 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -133,40 +133,49 @@ function gather_testsuite(Backend) @test y isa Array{T,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) end + + @testset "gather gradient for scalar index" begin + src = adapt(backend, Float64[3, 4, 5, 6, 7]) + idx = adapt(backend, [ + 1 2 3 4; + 4 2 1 3; + 3 5 5 3]) + dst = adapt(backend, Float64[ + 3 4 5 6; + 6 4 3 5; + 5 7 7 5]) + backend == cpu ? 
+ gradtest_fn(xs -> gather!(dst, xs, idx), src) : + gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx) + backend == cpu ? + gradtest_fn(xs -> gather(xs, idx), src) : + gradtest_fn((s, i) -> gather(s, i), src, idx) + end + + @testset "gather gradient for tuple index" begin + src = adapt(backend, Float64[ + 3 5 7 + 4 6 8]) + idx = adapt(backend, [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) + dst = adapt(backend, Float64[3, 5, 7, 4, 6, 8]) + backend == cpu ? + gradtest_fn(xs -> gather!(dst, xs, idx), src) : + gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx) + backend == cpu ? + gradtest_fn(xs -> gather(xs, idx), src) : + gradtest_fn((s, i) -> gather(s, i), src, idx) + end + + @testset "gather(src, IJK...)" begin + x = adapt(backend, reshape([1:15;], 3, 5)) + i, j = adapt(backend, [1,2]), adapt(backend, [2,4]) + y = gather(x, i, j) + @test adapt(cpu, y) == [4, 11] + y = gather(x, adapt(backend, [1, 2])) + @test adapt(cpu, y) == [ + 1 4 + 2 5 + 3 6] + end end -# @testset "gather gradient for scalar index" begin -# T = Float64 -# src = T[3, 4, 5, 6, 7] -# index = [1 2 3 4; -# 4 2 1 3; -# 3 5 5 3] -# dst = T[3 4 5 6; -# 6 4 3 5; -# 5 7 7 5] - -# gradtest(xs -> gather!(dst, xs, index), src) -# gradtest(xs -> gather(xs, index), src) -# end - -# @testset "gather gradient for tuple index" begin -# T = Float64 -# src = T[3 5 7 -# 4 6 8] -# index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] -# dst = T[3, 5, 7, 4, 6, 8] - -# gradtest(xs -> gather!(dst, xs, index), src) -# gradtest(xs -> gather(xs, index), src) -# end - -# @testset "gather(src, IJK...)" begin -# x = reshape([1:15;], 3, 5) - -# y = gather(x, [1,2], [2,4]) -# @test y == [4, 11] - -# @test gather(x, [1, 2]) == [1 4 -# 2 5 -# 3 6] -# end diff --git a/test/runtests.jl b/test/runtests.jl index a3db7c8f2..a292a072c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,12 +33,12 @@ include("scatter.jl") include("upsample.jl") function nnlib_testsuite(Backend; skip_tests = Set{String}()) - # @conditional_testset "Upsample" skip_tests begin - # upsample_testsuite(Backend) - # end - # @conditional_testset "Gather" skip_tests begin - # gather_testsuite(Backend) - # end + @conditional_testset "Upsample" skip_tests begin + upsample_testsuite(Backend) + end + @conditional_testset "Gather" skip_tests begin + gather_testsuite(Backend) + end @conditional_testset "Scatter" skip_tests begin scatter_testsuite(Backend) end @@ -88,107 +88,107 @@ end end end - # @testset verbose=true "Tests" begin - # if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" - # using CUDA - # if CUDA.functional() - # import Pkg - # using NNlibCUDA - # @testset "CUDA" begin - # Pkg.test("NNlibCUDA") - # end - # else - # @info "Insufficient version or CUDA not found; Skipping CUDA tests" - # end - # else - # @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them" - # end - - # if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true" - # import Pkg - # test_info = Pkg.project() - # # Add MIOpen_jll to AMDGPU. - # Pkg.develop("AMDGPU") - # Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU")) - # Pkg.add("MIOpen_jll") - # Pkg.update() - # # Update test project. - # Pkg.activate(test_info.path) - # Pkg.update() - - # using AMDGPU - # AMDGPU.versioninfo() - # if AMDGPU.functional() && AMDGPU.functional(:MIOpen) - # @show AMDGPU.MIOpen.version() - # @testset "AMDGPU" begin - # include("ext_amdgpu/runtests.jl") - # end - # else - # @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." 
- # end
- # else
- # @info "Skipping AMDGPU tests, set NNLIB_TEST_CUDA=true to run them."
- # end
-
- # @testset "Doctests" begin
- # doctest(NNlib, manual=false)
- # end
-
- # @testset "Activation Functions" begin
- # include("activations.jl")
- # end
-
- # @testset "Attention" begin
- # include("attention.jl")
- # end
-
- # @testset "Batched Multiplication" begin
- # include("batchedmul.jl")
- # end
-
- # @testset "Convolution" begin
- # include("conv.jl")
- # include("conv_bias_act.jl")
- # end
-
- # @testset "CTC Loss" begin
- # include("ctc.jl")
- # end
-
- # @testset "Dropout" begin
- # include("dropout.jl")
- # end
-
- # @testset "Fold/Unfold" begin
- # include("fold.jl")
- # end
-
- # @testset "Inference" begin
- # include("inference.jl")
- # end
-
- # @testset "Pooling" begin
- # include("pooling.jl")
- # end
-
- # @testset "Padding" begin
- # include("padding.jl")
- # end
-
- # @testset "Softmax" begin
- # include("softmax.jl")
- # end
-
- # @testset "Utilities" begin
- # include("utils.jl")
- # end
-
- # @testset "Grid Sampling" begin
- # include("sampling.jl")
- # end
-
- # @testset "Functions" begin
- # include("functions.jl")
- # end
- # end
+ @testset verbose=true "Tests" begin
+ if get(ENV, "NNLIB_TEST_CUDA", "false") == "true"
+ using CUDA
+ if CUDA.functional()
+ import Pkg
+ using NNlibCUDA
+ @testset "CUDA" begin
+ Pkg.test("NNlibCUDA")
+ end
+ else
+ @info "Insufficient version or CUDA not found; Skipping CUDA tests"
+ end
+ else
+ @info "Skipping CUDA tests, set NNLIB_TEST_CUDA=true to run them"
+ end
+
+ if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true"
+ import Pkg
+ test_info = Pkg.project()
+ # Add MIOpen_jll to AMDGPU.
+ Pkg.develop("AMDGPU")
+ Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU"))
+ Pkg.add("MIOpen_jll")
+ Pkg.update()
+ # Update test project.
+ Pkg.activate(test_info.path)
+ Pkg.update()
+
+ using AMDGPU
+ AMDGPU.versioninfo()
+ if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
+ @show AMDGPU.MIOpen.version()
+ @testset "AMDGPU" begin
+ include("ext_amdgpu/runtests.jl")
+ end
+ else
+ @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests."
+ end
+ else
+ @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them."
+ end + + @testset "Doctests" begin + doctest(NNlib, manual=false) + end + + @testset "Activation Functions" begin + include("activations.jl") + end + + @testset "Attention" begin + include("attention.jl") + end + + @testset "Batched Multiplication" begin + include("batchedmul.jl") + end + + @testset "Convolution" begin + include("conv.jl") + include("conv_bias_act.jl") + end + + @testset "CTC Loss" begin + include("ctc.jl") + end + + @testset "Dropout" begin + include("dropout.jl") + end + + @testset "Fold/Unfold" begin + include("fold.jl") + end + + @testset "Inference" begin + include("inference.jl") + end + + @testset "Pooling" begin + include("pooling.jl") + end + + @testset "Padding" begin + include("padding.jl") + end + + @testset "Softmax" begin + include("softmax.jl") + end + + @testset "Utilities" begin + include("utils.jl") + end + + @testset "Grid Sampling" begin + include("sampling.jl") + end + + @testset "Functions" begin + include("functions.jl") + end + end end diff --git a/test/scatter.jl b/test/scatter.jl index 05cdc2e42..f60d980e4 100644 --- a/test/scatter.jl +++ b/test/scatter.jl @@ -72,7 +72,7 @@ function test_scatter(backend, types, ops; pt, ops_skip_types) cpu = CPU() for T in types PT = promote_type(T, pt) - @testset failfast=true "$T" begin + @testset "$T" begin for op in ops skip_types = get(ops_skip_types, op, []) @testset "$op" begin @@ -124,7 +124,7 @@ function scatter_testsuite(Backend) else # Need LLVM 15+ for atomic fmin/fmax: # https://reviews.llvm.org/D127041 - # But min/max can be done by reinterpreting an array to `UInt`. + # But fmin/fmax can be done by reinterpreting an array to `UInt`. [Int32, Int64, UInt32, UInt64] end ops = backend == CPU() ? @@ -148,40 +148,65 @@ function scatter_testsuite(Backend) if backend == CPU() @testset "scatter exceptions" begin - @test_throws AssertionError scatter!(+, dsts[0], srcs[(1, true)], idxs[:int]) idx = [1 2 3 4; 4 2 1 3; 6 7 8 9] - @test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx) + @test_throws AssertionError scatter!(+, copy(dsts[0]), srcs[(1, true)], idxs[:int]) + @test_throws BoundsError scatter!(+, copy(dsts[1]), srcs[(1, true)], idx) end end -end -# @testset "∇scatter" begin -# T = Float64 -# fdm(op) = op == min ? :backward : :forward -# # fdm(op) = :forward + @testset "∇scatter" begin + T = Float64 + fdm(op) = op == min ? :backward : :forward -# @testset "dstsize" begin -# idx = [2, 2, 3, 4, 4] -# src = ones(3, 5) -# y = scatter(+, src, idx, dstsize = (3, 6)) -# @test size(y) == (3, 6) -# gradtest(x -> scatter(+, x, idx, dstsize = (3,6)), src) -# end + @testset "dstsize" begin + idx = adapt(backend, [2, 2, 3, 4, 4]) + src = adapt(backend, ones(T, 3, 5)) + y = scatter(+, src, idx, dstsize = (3, 6)) + @test eltype(y) == T + @test size(y) == (3, 6) + backend == CPU() ? + gradtest_fn(x -> scatter(+, x, idx; dstsize=(3, 6)), src) : + gradtest_fn((x, i) -> scatter(+, x, i; dstsize=(3, 6)), src, idx) + end -# @testset "∂dst" begin -# for op in (+, -, *, /, mean, max, min) -# gradtest(xs -> scatter!(op, copy(xs), srcs[(0, true)], idxs[:int]), T.(dsts[0]), fdm=fdm(op)) -# gradtest(xs -> scatter!(op, copy(xs), srcs[(1, true)], idxs[:int]), T.(dsts[1]), fdm=fdm(op)) -# end -# end + @testset "∂dst" begin + ops = if backend == CPU() || Symbol(typeof(backend)) == :CUDABackend + (+, -, *, /, mean, max, min) + else + (+, -, mean, max, min) + end + for op in ops, i in (0, 1) + PT = ( # If not CPU and CUDA -> use Int64 for min/max. 
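+ # (atomic fmin/fmax needs LLVM 15+, as noted in the `types` selection
+ # above, so those backends exercise min/max with an integer eltype)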
+ backend != CPU() && + Symbol(typeof(backend)) != :CUDABackend && + (op == max || op == min)) ? Int64 : T -# @testset "∂src" begin -# for op in (+, -, *, /, mean, max, min) -# gradtest(xs -> scatter!(op, T.(dsts[0]), xs, idxs[:int]), T.(srcs[(0, true)]), fdm=fdm(op)) -# gradtest(xs -> scatter!(op, T.(dsts[1]), xs, idxs[:int]), T.(srcs[(1, true)]), fdm=fdm(op)) + src = adapt(backend, srcs[(i, true)]) + idx = adapt(backend, idxs[:int]) + dst = adapt(backend, PT.(dsts[i])) + backend == CPU() ? + gradtest_fn(x -> scatter!(op, copy(x), src, idx), dst; fdm=fdm(op)) : + gradtest_fn((x, s, i) -> scatter!(op, x, s, i), dst, src, idx) + end + end -# gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(0, false)]), fdm=fdm(op)) -# gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(1, false)]), fdm=fdm(op)) -# end -# end -# end + @testset "∂src" begin + ops = if backend == CPU() || Symbol(typeof(backend)) == :CUDABackend + (+, -, *, /, mean, max, min) + else + (+, -, mean, max, min) + end + for op in ops, i in (0, 1) + PT = ( # If not CPU and CUDA -> use Int64 for min/max. + backend != CPU() && + Symbol(typeof(backend)) != :CUDABackend && + (op == max || op == min)) ? Int64 : T + src = PT.(adapt(backend, srcs[(i, false)])) + idx = adapt(backend, idxs[:int]) + backend == CPU() ? + gradtest_fn(xs -> scatter(op, xs, idx), src; fdm=fdm(op)) : + gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx) + end + end + end +end From fac67d5dc2699cbbcb794f80c4236d958e7f3af4 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sun, 9 Apr 2023 21:17:36 +0300 Subject: [PATCH 04/10] Cleanup --- ext/NNlibCUDA/src/NNlibCUDA.jl | 1 - ext/NNlibCUDA/src/scatter.jl | 187 --------------------------------- ext/NNlibCUDA/test/runtests.jl | 1 - ext/NNlibCUDA/test/scatter.jl | 106 ------------------- 4 files changed, 295 deletions(-) delete mode 100644 ext/NNlibCUDA/src/scatter.jl delete mode 100644 ext/NNlibCUDA/test/scatter.jl diff --git a/ext/NNlibCUDA/src/NNlibCUDA.jl b/ext/NNlibCUDA/src/NNlibCUDA.jl index e1c3c225c..c55a8ab10 100644 --- a/ext/NNlibCUDA/src/NNlibCUDA.jl +++ b/ext/NNlibCUDA/src/NNlibCUDA.jl @@ -12,7 +12,6 @@ include("batchedadjtrans.jl") include("batchedmul.jl") include("ctc.jl") include("fold.jl") -include("scatter.jl") include("utils.jl") include("cudnn/cudnn.jl") include("cudnn/conv.jl") diff --git a/ext/NNlibCUDA/src/scatter.jl b/ext/NNlibCUDA/src/scatter.jl deleted file mode 100644 index f839554f1..000000000 --- a/ext/NNlibCUDA/src/scatter.jl +++ /dev/null @@ -1,187 +0,0 @@ -# supported op: +, -, *, /, max, min, &, |, mean - -function scatter_kernel!(op::OP, dst, src, idx) where OP - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= length(idx) - CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src[index]) - end - return nothing -end - -function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}) where OP - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= length(idx) - li = Base._to_linear_index(dst, Tuple(idx[index])...) - CUDA.@atomic dst[li] = op(dst[li], src[index]) - end - return nothing -end - -function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size) where OP - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) - dims_i = CartesianIndices(dims_size)[k+1] - CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] 
= op(dst[Tuple(dims_i)..., idx[j+1]...], src[index]) - end - return nothing -end - -function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size) where OP - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) - dims_i = CartesianIndices(dims_size)[k+1] - li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...) - CUDA.@atomic dst[li] = op(dst[li], src[index]) - end - return nothing -end - -function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) where OP - dims = NNlib.scatter_dims(dst, src, idx) - args = if dims == 0 - max_idx = length(idx) - op, dst, src, idx - else - dims_size = size(dst)[1:dims] - max_dims_idx = prod(dims_size) - max_idx = max_dims_idx * length(idx) - op, dst, src, idx, max_idx, max_dims_idx, dims_size - end - - kernel = @cuda launch=false scatter_kernel!(args...) - config = launch_configuration(kernel.fun; max_threads=256) - threads = min(max_idx, config.threads) - blocks = cld(max_idx, threads) - kernel(args...; threads=threads, blocks=blocks) - return dst -end - -function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) - Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) - dst_ = NNlib.scatter!(+, zero(dst), src, idx) - dst .+= NNlib.safe_div.(dst_, Ns) - return dst -end - - -## Gradients - -function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - cart_j = CartesianIndices(idx)[index] - # get aggregating indeices, which is to be aggregated together, and itself index - inds = rev_idx[idx[cart_j]...] - # multiply all values to be aggregated but not itself - x = one(T) - for k in inds - x *= src[k] - end - x /= src[cart_j] - # apply `op` on `Δsrc[i, k]` and `x` - Δsrc[cart_j] = op(Δsrc[cart_j], x) - end - return nothing -end - -function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - cart_j = CartesianIndices(idx)[index] - # get aggregating indeices, which is to be aggregated together, and itself index - inds = rev_idx[Tuple(idx[cart_j])...] - # multiply all values to be aggregated but not itself - x = one(T) - for k in inds - x *= src[k] - end - x /= src[cart_j] - # apply `op` on `Δsrc[i, k]` and `x` - Δsrc[cart_j] = op(Δsrc[cart_j], x) - end - return nothing -end - -function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, - rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - i, j = fldmod1(index, max_dims_idx) - cart_i = CartesianIndices(idx)[i] - cart_j = pre_cart_idx[j] - # get aggregating indeices, which is to be aggregated together, and itself index - inds = rev_idx[idx[cart_i]...] - # multiply all values to be aggregated but not itself - x = one(T) - for k in inds - jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...) 
- x *= src[jk] - end - x /= src[index] - # apply `op` on `Δsrc[i, k]` and `x` - Δsrc[index] = op(Δsrc[index], x) - end - return nothing -end - -function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - i, j = fldmod1(index, max_dims_idx) - cart_i = CartesianIndices(idx)[i] - cart_j = pre_cart_idx[j] - # get aggregating indeices, which is to be aggregated together, and itself index - inds = rev_idx[Tuple(idx[cart_i])...] - # multiply all values to be aggregated but not itself - x = one(T) - for k in inds - jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...) - x *= src[jk] - end - x /= src[index] - # apply `op` on `Δsrc[i, k]` and `x` - Δsrc[index] = op(Δsrc[index], x) - end - return nothing -end - -function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::AnyCuArray{Tsrc,Nsrc}, - idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} - dims = Nsrc - Nidx - Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) - rev_idx = NNlib.reverse_indices(idx) - rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx)) - - if dims == 0 - max_idx = length(idx) - args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc - else - pre_cart_idx = CartesianIndices(axes(src)[1:dims]) - max_dims_idx = length(pre_cart_idx) - max_idx = max_dims_idx * length(idx) - args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc - end - - kernel = @cuda launch=false ∇scatter_src_kernel!(args...) - config = launch_configuration(kernel.fun; max_threads=256) - threads = min(max_idx, config.threads) - blocks = cld(max_idx, threads) - kernel(args...; threads=threads, blocks=blocks) - - CUDA.unsafe_free!(rev_idx) - return Δsrc -end diff --git a/ext/NNlibCUDA/test/runtests.jl b/ext/NNlibCUDA/test/runtests.jl index 8af877bba..fab782d49 100644 --- a/ext/NNlibCUDA/test/runtests.jl +++ b/ext/NNlibCUDA/test/runtests.jl @@ -19,7 +19,6 @@ include("fold.jl") include("pooling.jl") include("softmax.jl") include("batchnorm.jl") -include("scatter.jl") include("gather.jl") include("sampling.jl") end diff --git a/ext/NNlibCUDA/test/scatter.jl b/ext/NNlibCUDA/test/scatter.jl deleted file mode 100644 index a4977f285..000000000 --- a/ext/NNlibCUDA/test/scatter.jl +++ /dev/null @@ -1,106 +0,0 @@ -dsts = Dict( - 0 => cu([3, 4, 5, 6, 7]), - 1 => cu([3 3 4 4 5; - 5 5 6 6 7]), -) -srcs = Dict( - (0, true) => cu(ones(Int, 3, 4)), - (0, false) => cu(ones(Int, 3) * collect(1:4)'), - (1, true) => cu(ones(Int, 2, 3, 4)), - (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)), -) -idxs = [ - cu([1 2 3 4; - 4 2 1 3; - 3 5 5 3]), # integer index - cu([(1,) (2,) (3,) (4,); - (4,) (2,) (1,) (3,); - (3,) (5,) (5,) (3,)]), # tuple index - cu(CartesianIndex.([(1,) (2,) (3,) (4,); - (4,) (2,) (1,) (3,); - (3,) (5,) (5,) (3,)])), # CartesianIndex index -] - -types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] - - -@testset "scatter" begin - for T = types - @testset "$(T)" begin - @testset "+" begin - for idx = idxs, dims = [0, 1] - mutated = true - gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) - end - end - - @testset "-" begin - for idx = idxs, dims = [0, 1] - mutated = true - gputest((dst, src) -> 
NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) - end - end - - @testset "max" begin - for idx = idxs, dims = [0, 1] - mutated = true - gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) - end - end - - @testset "min" begin - for idx = idxs, dims = [0, 1] - mutated = true - gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) - end - end - end - end - - - for T = [CuArray{Float32}, CuArray{Float64}] - @testset "$(T)" begin - @testset "*" begin - for idx = idxs, dims = [0, 1] - mutated = true - gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) - end - end - - @testset "/" begin - for idx = idxs, dims = [0, 1] - mutated = true - gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) - end - end - - @testset "mean" begin - for idx = idxs, dims = [0, 1] - mutated = true - gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) - end - end - end - end -end From 50cc60bb945d9837603059a7eea41ef18a9eba90 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 11 Apr 2023 13:38:26 +0300 Subject: [PATCH 05/10] Update tests --- test/gather.jl | 82 ++++++++++++++++++++++++------------------------ test/scatter.jl | 75 ++++++++++++++++++++++--------------------- test/upsample.jl | 2 +- 3 files changed, 79 insertions(+), 80 deletions(-) diff --git a/test/gather.jl b/test/gather.jl index 1cf88843f..e3221145b 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -1,14 +1,14 @@ using NNlib: gather, gather! function gather_testsuite(Backend) - cpu, backend = CPU(), Backend() + device(x) = adapt(Backend(), x) + gradtest_fn = Backend == CPU ? gradtest : gputest T = Float32 - gradtest_fn = backend == CPU() ? 
gradtest : gputest @testset "gather scalar index" begin ## 1d src, 2d index of ints -> 2d output - src = adapt(backend, T[3, 4, 5, 6, 7]) - index = adapt(backend, [ + src = device(T[3, 4, 5, 6, 7]) + index = device([ 1 2 3 4; 4 2 1 3; 3 5 5 3]) @@ -17,14 +17,14 @@ function gather_testsuite(Backend) 6 4 3 5; 5 7 7 5] - y = adapt(cpu, gather(src, index)) + y = cpu(gather(src, index)) @test y isa Array{T,2} @test size(y) == size(index) @test y == output - dst = adapt(backend, T.(zero(index))) - @test adapt(cpu, gather!(dst, src, index)) == output - dst = adapt(backend, zeros(T, 3, 5)) + dst = device(T.(zero(index))) + @test cpu(gather!(dst, src, index)) == output + dst = device(zeros(T, 3, 5)) @test_throws ArgumentError gather!(dst, src, index) if Backend == CPU @@ -35,8 +35,8 @@ function gather_testsuite(Backend) end ## 1d src, 3d index of ints -> 3d output - src = adapt(backend, T[3, 4, 5, 6, 7]) - index = adapt(backend, [ + src = device(T[3, 4, 5, 6, 7]) + index = device([ 1 2 3 4; 4 2 1 3; 3 5 5 3][:,:,1:1]) @@ -45,16 +45,16 @@ function gather_testsuite(Backend) 6 4 3 5; 5 7 7 5][:,:,1:1] - y = adapt(cpu, gather(src, index)) + y = cpu(gather(src, index)) @test y isa Array{T,3} @test size(y) == size(index) @test y == output ## 2d src, 2d index of ints -> 3d output - src = adapt(backend, T[ + src = device(T[ 3 5 7 4 6 8]) - index = adapt(backend, [ + index = device([ 1 2 3; 2 2 1; 3 1 3]) @@ -70,7 +70,7 @@ function gather_testsuite(Backend) 7 3 7 8 4 8] - y = adapt(cpu, gather(src, index)) + y = cpu(gather(src, index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,3} @@ -80,13 +80,13 @@ function gather_testsuite(Backend) @testset "gather tuple index" begin ## 2d src, 1d index of 2-tuples -> 1d output - src = adapt(backend, T[ + src = device(T[ 3 5 7 4 6 8]) - index = adapt(backend, [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) + index = device([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) output = T[3, 5, 7, 4, 6, 8] - y = adapt(cpu, gather(src, index)) + y = cpu(gather(src, index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,1} @@ -95,11 +95,11 @@ function gather_testsuite(Backend) ## 3d src, 2d index of 2-tuples -> 3d output n1, nsrc, nidx = 2, 3, 6 - src = adapt(backend, rand(T, n1, nsrc, nsrc)) - index = adapt(backend, [ + src = device(rand(T, n1, nsrc, nsrc)) + index = device([ (rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx]) - y = adapt(cpu, gather(src, index)) + y = cpu(gather(src, index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,3} @@ -108,13 +108,13 @@ function gather_testsuite(Backend) @testset "gather cartesian index" begin ## 2d src, 1d index of 2-tuples -> 1d output - src = adapt(backend, T[ + src = device(T[ 3 5 7 4 6 8]) - index = adapt(backend, CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)])) + index = device(CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)])) output = T[3, 5, 7, 4, 6, 8] - y = adapt(cpu, gather(src, index)) + y = cpu(gather(src, index)) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,1} @@ -123,11 +123,11 @@ function gather_testsuite(Backend) ## 3d src, 2d index of 2-tuples -> 3d output n1, nsrc, nidx = 2, 3, 6 - src = adapt(backend, rand(Float32, n1, nsrc, nsrc)) - index = adapt(backend, [ + src = device(rand(Float32, n1, nsrc, nsrc)) + index = device([ CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i=1:nidx, j=1:nidx]) - y = adapt(cpu, gather(src, index)) + y = cpu(gather(src, index)) M = 
NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,3} @@ -135,44 +135,44 @@ function gather_testsuite(Backend) end @testset "gather gradient for scalar index" begin - src = adapt(backend, Float64[3, 4, 5, 6, 7]) - idx = adapt(backend, [ + src = device(Float64[3, 4, 5, 6, 7]) + idx = device([ 1 2 3 4; 4 2 1 3; 3 5 5 3]) - dst = adapt(backend, Float64[ + dst = device(Float64[ 3 4 5 6; 6 4 3 5; 5 7 7 5]) - backend == cpu ? + Backend == CPU ? gradtest_fn(xs -> gather!(dst, xs, idx), src) : gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx) - backend == cpu ? + Backend == CPU ? gradtest_fn(xs -> gather(xs, idx), src) : gradtest_fn((s, i) -> gather(s, i), src, idx) end @testset "gather gradient for tuple index" begin - src = adapt(backend, Float64[ + src = device(Float64[ 3 5 7 4 6 8]) - idx = adapt(backend, [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) - dst = adapt(backend, Float64[3, 5, 7, 4, 6, 8]) - backend == cpu ? + idx = device([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) + dst = device(Float64[3, 5, 7, 4, 6, 8]) + Backend == CPU ? gradtest_fn(xs -> gather!(dst, xs, idx), src) : gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx) - backend == cpu ? + Backend == CPU ? gradtest_fn(xs -> gather(xs, idx), src) : gradtest_fn((s, i) -> gather(s, i), src, idx) end @testset "gather(src, IJK...)" begin - x = adapt(backend, reshape([1:15;], 3, 5)) - i, j = adapt(backend, [1,2]), adapt(backend, [2,4]) + x = device(reshape([1:15;], 3, 5)) + i, j = device([1,2]), device([2,4]) y = gather(x, i, j) - @test adapt(cpu, y) == [4, 11] - y = gather(x, adapt(backend, [1, 2])) - @test adapt(cpu, y) == [ + @test cpu(y) == [4, 11] + y = gather(x, device([1, 2])) + @test cpu(y) == [ 1 4 2 5 3 6] diff --git a/test/scatter.jl b/test/scatter.jl index f60d980e4..26fc06cde 100644 --- a/test/scatter.jl +++ b/test/scatter.jl @@ -68,8 +68,7 @@ res = Dict( 4. 4. 6. 5. 
5.], ) -function test_scatter(backend, types, ops; pt, ops_skip_types) - cpu = CPU() +function test_scatter(device, types, ops; pt, ops_skip_types) for T in types PT = promote_type(T, pt) @testset "$T" begin @@ -77,28 +76,28 @@ function test_scatter(backend, types, ops; pt, ops_skip_types) skip_types = get(ops_skip_types, op, []) @testset "$op" begin for idx = values(idxs), dims = [0, 1] - idx = adapt(backend, idx) - dst = adapt(backend, dsts[dims]) + idx = device(idx) + dst = device(dsts[dims]) mutated = true target_y = res[(op, dims, mutated)] - src = adapt(backend, srcs[(dims, mutated)]) + src = device(srcs[(dims, mutated)]) if op == / src = src .* T(2) end - @test adapt(cpu, scatter!(op, T.(dst), T.(src), idx)) == T.(target_y) - @test adapt(cpu, scatter!(op, T.(dst), src, idx)) == PT.(target_y) + @test cpu(scatter!(op, T.(dst), T.(src), idx)) == T.(target_y) + @test cpu(scatter!(op, T.(dst), src, idx)) == PT.(target_y) if op == / - @test adapt(cpu, scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y) + @test cpu(scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y) else - @test adapt(cpu, scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y) + @test cpu(scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y) end if T ∉ skip_types mutated = false - src = adapt(backend, srcs[(dims, mutated)]) - @test adapt(cpu, scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)]) + src = device(srcs[(dims, mutated)]) + @test cpu(scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)]) end end end @@ -108,8 +107,8 @@ function test_scatter(backend, types, ops; pt, ops_skip_types) end function scatter_testsuite(Backend) - backend = Backend() - gradtest_fn = backend == CPU() ? gradtest : gputest + device(x) = adapt(Backend(), x) + gradtest_fn = Backend == CPU ? gradtest : gputest ops_skip_types = Dict( (+) => [], @@ -117,9 +116,9 @@ function scatter_testsuite(Backend) (*) => [UInt8, Int8], max => [BigInt], min => [BigInt]) - types = if backend == CPU() + types = if Backend == CPU [UInt8, UInt32, UInt64, Int32, Int64, Float16, Float32, Float64, BigFloat, Rational] - elseif Symbol(typeof(backend)) == :CUDABackend + elseif Symbol(Backend) == :CUDABackend [Int32, Int64, Float32, Float64] else # Need LLVM 15+ for atomic fmin/fmax: @@ -127,26 +126,26 @@ function scatter_testsuite(Backend) # But fmin/fmax can be done by reinterpreting an array to `UInt`. [Int32, Int64, UInt32, UInt64] end - ops = backend == CPU() ? + ops = Backend == CPU ? (+, -, max, min, *) : (+, -, max, min) - test_scatter(backend, types, ops; pt=Int, ops_skip_types) + test_scatter(device, types, ops; pt=Int, ops_skip_types) - types = backend == CPU() ? + types = Backend == CPU ? [Float16, Float32, BigFloat, Rational] : [Float32, Float64] - ops = if backend == CPU() + ops = if Backend == CPU (/, mean) - elseif Symbol(typeof(backend)) == :CUDABackend + elseif Symbol(Backend) == :CUDABackend (*, /, mean) else # LLVM does not support atomic fmul/fdiv: # https://llvm.org/docs/LangRef.html#atomicrmw-instruction (mean,) end - test_scatter(backend, types, ops; pt=Float64, ops_skip_types=Dict()) + test_scatter(device, types, ops; pt=Float64, ops_skip_types=Dict()) - if backend == CPU() + if Backend == CPU @testset "scatter exceptions" begin idx = [1 2 3 4; 4 2 1 3; 6 7 8 9] @test_throws AssertionError scatter!(+, copy(dsts[0]), srcs[(1, true)], idxs[:int]) @@ -159,51 +158,51 @@ function scatter_testsuite(Backend) fdm(op) = op == min ? 
:backward : :forward @testset "dstsize" begin - idx = adapt(backend, [2, 2, 3, 4, 4]) - src = adapt(backend, ones(T, 3, 5)) + idx = device([2, 2, 3, 4, 4]) + src = device(ones(T, 3, 5)) y = scatter(+, src, idx, dstsize = (3, 6)) @test eltype(y) == T @test size(y) == (3, 6) - backend == CPU() ? + Backend == CPU ? gradtest_fn(x -> scatter(+, x, idx; dstsize=(3, 6)), src) : gradtest_fn((x, i) -> scatter(+, x, i; dstsize=(3, 6)), src, idx) end @testset "∂dst" begin - ops = if backend == CPU() || Symbol(typeof(backend)) == :CUDABackend + ops = if Backend == CPU || Symbol(Backend) == :CUDABackend (+, -, *, /, mean, max, min) else (+, -, mean, max, min) end for op in ops, i in (0, 1) PT = ( # If not CPU and CUDA -> use Int64 for min/max. - backend != CPU() && - Symbol(typeof(backend)) != :CUDABackend && + Backend != CPU && + Symbol(Backend) != :CUDABackend && (op == max || op == min)) ? Int64 : T - src = adapt(backend, srcs[(i, true)]) - idx = adapt(backend, idxs[:int]) - dst = adapt(backend, PT.(dsts[i])) - backend == CPU() ? + src = device(srcs[(i, true)]) + idx = device(idxs[:int]) + dst = device(PT.(dsts[i])) + Backend == CPU ? gradtest_fn(x -> scatter!(op, copy(x), src, idx), dst; fdm=fdm(op)) : gradtest_fn((x, s, i) -> scatter!(op, x, s, i), dst, src, idx) end end @testset "∂src" begin - ops = if backend == CPU() || Symbol(typeof(backend)) == :CUDABackend + ops = if Backend == CPU || Symbol(Backend) == :CUDABackend (+, -, *, /, mean, max, min) else (+, -, mean, max, min) end for op in ops, i in (0, 1) PT = ( # If not CPU and CUDA -> use Int64 for min/max. - backend != CPU() && - Symbol(typeof(backend)) != :CUDABackend && + Backend != CPU && + Symbol(Backend) != :CUDABackend && (op == max || op == min)) ? Int64 : T - src = PT.(adapt(backend, srcs[(i, false)])) - idx = adapt(backend, idxs[:int]) - backend == CPU() ? + src = PT.(device(srcs[(i, false)])) + idx = device(idxs[:int]) + Backend == CPU ? gradtest_fn(xs -> scatter(op, xs, idx), src; fdm=fdm(op)) : gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx) end diff --git a/test/upsample.jl b/test/upsample.jl index 28109d4b2..c13d64958 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -1,6 +1,6 @@ function upsample_testsuite(Backend) device(x) = adapt(Backend(), x) - gradtest_fn = KernelAbstractions.isgpu(Backend()) ? gputest : gradtest + gradtest_fn = Backend == CPU ? gradtest : gputest T = Float32 # TODO test against all supported eltypes for each backend. atol = T == Float32 ? 
1e-3 : 1e-6 From ac0e31067088eea97ee699d986deb161ade7c991 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 11 Apr 2023 15:52:39 +0300 Subject: [PATCH 06/10] Retain NNlibCUDA scatter kernels --- ext/NNlibCUDA/src/NNlibCUDA.jl | 1 + ext/NNlibCUDA/src/scatter.jl | 188 +++++++++++++++++++++++++++++++++ ext/NNlibCUDA/test/gather.jl | 22 ++-- ext/NNlibCUDA/test/runtests.jl | 1 + ext/NNlibCUDA/test/scatter.jl | 106 +++++++++++++++++++ test/runtests.jl | 8 +- 6 files changed, 311 insertions(+), 15 deletions(-) create mode 100644 ext/NNlibCUDA/src/scatter.jl create mode 100644 ext/NNlibCUDA/test/scatter.jl diff --git a/ext/NNlibCUDA/src/NNlibCUDA.jl b/ext/NNlibCUDA/src/NNlibCUDA.jl index c55a8ab10..e1c3c225c 100644 --- a/ext/NNlibCUDA/src/NNlibCUDA.jl +++ b/ext/NNlibCUDA/src/NNlibCUDA.jl @@ -12,6 +12,7 @@ include("batchedadjtrans.jl") include("batchedmul.jl") include("ctc.jl") include("fold.jl") +include("scatter.jl") include("utils.jl") include("cudnn/cudnn.jl") include("cudnn/conv.jl") diff --git a/ext/NNlibCUDA/src/scatter.jl b/ext/NNlibCUDA/src/scatter.jl new file mode 100644 index 000000000..874207c77 --- /dev/null +++ b/ext/NNlibCUDA/src/scatter.jl @@ -0,0 +1,188 @@ +# supported op: +, -, *, /, max, min, &, |, mean + +function scatter_kernel!(op::OP, dst, src, idx) where OP + index = threadIdx().x + (blockIdx().x - 1) * blockDim().x + + @inbounds if index <= length(idx) + CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src[index]) + end + return nothing +end + +function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}) where OP + index = threadIdx().x + (blockIdx().x - 1) * blockDim().x + + @inbounds if index <= length(idx) + li = Base._to_linear_index(dst, Tuple(idx[index])...) + CUDA.@atomic dst[li] = op(dst[li], src[index]) + end + return nothing +end + +function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size) where OP + index = threadIdx().x + (blockIdx().x - 1) * blockDim().x + + @inbounds if index <= max_idx + j, k = divrem(index-1, max_dims_idx) + dims_i = CartesianIndices(dims_size)[k+1] + CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] = op(dst[Tuple(dims_i)..., idx[j+1]...], src[index]) + end + return nothing +end + +function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, + max_idx, max_dims_idx, dims_size) where OP + index = threadIdx().x + (blockIdx().x - 1) * blockDim().x + + @inbounds if index <= max_idx + j, k = divrem(index-1, max_dims_idx) + dims_i = CartesianIndices(dims_size)[k+1] + li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...) + CUDA.@atomic dst[li] = op(dst[li], src[index]) + end + return nothing +end + +function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) where OP + dims = NNlib.scatter_dims(dst, src, idx) + args = if dims == 0 + max_idx = length(idx) + op, dst, src, idx + else + dims_size = size(dst)[1:dims] + max_dims_idx = prod(dims_size) + max_idx = max_dims_idx * length(idx) + op, dst, src, idx, max_idx, max_dims_idx, dims_size + end + + kernel = @cuda launch=false scatter_kernel!(args...) 
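+ # launch_configuration picks an occupancy-optimal thread count for the
+ # compiled kernel (capped at 256 here); cld then yields enough blocks
+ # to cover all max_idx elements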
+ config = launch_configuration(kernel.fun; max_threads=256)
+ threads = min(max_idx, config.threads)
+ blocks = cld(max_idx, threads)
+ kernel(args...; threads=threads, blocks=blocks)
+ return dst
+end
+
+function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
+ Ns = NNlib.scatter!(+, zero(dst), one.(src), idx)
+ dst_ = NNlib.scatter!(+, zero(dst), src, idx)
+ dst .+= NNlib.safe_div.(dst_, Ns)
+ return dst
+end
+
+
+## Gradients
+
+function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
+ rev_idx, max_idx, T::Type{TT}) where {OP,TT}
+ index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+ @inbounds if index <= max_idx
+ cart_j = CartesianIndices(idx)[index]
+ # get the indices whose values aggregate to the same destination, including this one
+ inds = rev_idx[idx[cart_j]...]
+ # multiply all values that aggregate together, excluding this one
+ x = one(T)
+ for k in inds
+ x *= src[k]
+ end
+ x /= src[cart_j]
+ # apply `op` on `Δsrc[i, k]` and `x`
+ Δsrc[cart_j] = op(Δsrc[cart_j], x)
+ end
+ return nothing
+end
+
+function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
+ rev_idx, max_idx, T::Type{TT}) where {OP,TT}
+ index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+ @inbounds if index <= max_idx
+ cart_j = CartesianIndices(idx)[index]
+ # get the indices whose values aggregate to the same destination, including this one
+ inds = rev_idx[Tuple(idx[cart_j])...]
+ # multiply all values that aggregate together, excluding this one
+ x = one(T)
+ for k in inds
+ x *= src[k]
+ end
+ x /= src[cart_j]
+ # apply `op` on `Δsrc[i, k]` and `x`
+ Δsrc[cart_j] = op(Δsrc[cart_j], x)
+ end
+ return nothing
+end
+
+function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
+ rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
+ index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+ @inbounds if index <= max_idx
+ i, j = fldmod1(index, max_dims_idx)
+ cart_i = CartesianIndices(idx)[i]
+ cart_j = pre_cart_idx[j]
+ # get the indices whose values aggregate to the same destination, including this one
+ inds = rev_idx[idx[cart_i]...]
+ # multiply all values that aggregate together, excluding this one
+ x = one(T)
+ for k in inds
+ jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
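+ # `jk` flattens (position within the leading dims, scattered-to index)
+ # into a single linear offset into `src`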
+ x *= src[jk] + end + x /= src[index] + # apply `op` on `Δsrc[i, k]` and `x` + Δsrc[index] = op(Δsrc[index], x) + end + return nothing +end + +function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, + src::AnyCuArray{Tsrc,Nsrc}, + idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} + dims = Nsrc - Nidx + Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) + rev_idx = NNlib.reverse_indices(idx) + rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx)) + + if dims == 0 + max_idx = length(idx) + args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc + else + pre_cart_idx = CartesianIndices(axes(src)[1:dims]) + max_dims_idx = length(pre_cart_idx) + max_idx = max_dims_idx * length(idx) + args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc + end + + kernel = @cuda launch=false ∇scatter_src_kernel!(args...) + config = launch_configuration(kernel.fun; max_threads=256) + threads = min(max_idx, config.threads) + blocks = cld(max_idx, threads) + kernel(args...; threads=threads, blocks=blocks) + + CUDA.unsafe_free!(rev_idx) + return Δsrc +end diff --git a/ext/NNlibCUDA/test/gather.jl b/ext/NNlibCUDA/test/gather.jl index f200f77b7..36d42dbcc 100644 --- a/ext/NNlibCUDA/test/gather.jl +++ b/ext/NNlibCUDA/test/gather.jl @@ -1,7 +1,7 @@ -@testset "gather" begin +@testset "gather" begin T = Float32 CT = CuArray{Float32} - + ## 1d src, 2d index of ints -> 2d output src = CT([3, 4, 5, 6, 7]) index = cu([1 2 3 4; @@ -10,14 +10,14 @@ output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5]) - + y = NNlib.gather(src, index) @test y isa CuArray{Float32,2} @test size(y) == size(index) gputest(src -> NNlib.gather(src, index), src, checkgrad=true) @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index) - + ## 1d src, 2d index of tuples -> 2d output src = CT([3, 4, 5, 6, 7]) index = cu([(1,) (2,) (3,) (4,); @@ -26,14 +26,14 @@ output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5]) - + y = NNlib.gather(src, index) @test y isa CuArray{Float32,2} @test size(y) == size(index) gputest(src -> NNlib.gather(src, index), src, checkgrad=true) @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index) - + ## 1d src, 2d index of CartesianIndex -> 2d output src = CT([3, 4, 5, 6, 7]) index = cu(CartesianIndex.([(1,) (2,) (3,) (4,); @@ -42,7 +42,7 @@ output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5]) - + y = NNlib.gather(src, index) @test y isa CuArray{Float32,2} @test size(y) == size(index) @@ -66,7 +66,7 @@ ## 2d src, 2d index of ints -> 3d output - src = CT([3 5 7 + src = CT([3 5 7 4 6 8]) index = cu([1 2 3; 2 2 1; @@ -79,14 +79,14 @@ output[:,:,2] = [5 5 3 6 6 4] - + output[:,:,3] = [7 3 7 8 4 8] - + y = NNlib.gather(src, index) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa CuArray{Float32,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) 
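+ # (the gathered array keeps the leading Nsrc - M source dimensions and
+ # appends the index shape, so a 2-d src with a 2-d integer index
+ # produces a 3-d output)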
gputest(src -> NNlib.gather(src, index), src, checkgrad=true) end diff --git a/ext/NNlibCUDA/test/runtests.jl b/ext/NNlibCUDA/test/runtests.jl index fab782d49..8af877bba 100644 --- a/ext/NNlibCUDA/test/runtests.jl +++ b/ext/NNlibCUDA/test/runtests.jl @@ -19,6 +19,7 @@ include("fold.jl") include("pooling.jl") include("softmax.jl") include("batchnorm.jl") +include("scatter.jl") include("gather.jl") include("sampling.jl") end diff --git a/ext/NNlibCUDA/test/scatter.jl b/ext/NNlibCUDA/test/scatter.jl new file mode 100644 index 000000000..a4977f285 --- /dev/null +++ b/ext/NNlibCUDA/test/scatter.jl @@ -0,0 +1,106 @@ +dsts = Dict( + 0 => cu([3, 4, 5, 6, 7]), + 1 => cu([3 3 4 4 5; + 5 5 6 6 7]), +) +srcs = Dict( + (0, true) => cu(ones(Int, 3, 4)), + (0, false) => cu(ones(Int, 3) * collect(1:4)'), + (1, true) => cu(ones(Int, 2, 3, 4)), + (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)), +) +idxs = [ + cu([1 2 3 4; + 4 2 1 3; + 3 5 5 3]), # integer index + cu([(1,) (2,) (3,) (4,); + (4,) (2,) (1,) (3,); + (3,) (5,) (5,) (3,)]), # tuple index + cu(CartesianIndex.([(1,) (2,) (3,) (4,); + (4,) (2,) (1,) (3,); + (3,) (5,) (5,) (3,)])), # CartesianIndex index +] + +types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] + + +@testset "scatter" begin + for T = types + @testset "$(T)" begin + @testset "+" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + end + end + + @testset "-" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + end + end + + @testset "max" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + end + end + + @testset "min" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + end + end + end + end + + + for T = [CuArray{Float32}, CuArray{Float64}] + @testset "$(T)" begin + @testset "*" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + end + end + + @testset "/" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + end + end + + @testset "mean" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> 
NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + end + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a292a072c..ca9dd3def 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,15 +46,15 @@ end @testset "NNlib.jl" verbose=true begin @testset verbose=true "Test Suite" begin - @testset "CPU" begin - nnlib_testsuite(CPU) - end + # @testset "CPU" begin + # nnlib_testsuite(CPU) + # end if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" using CUDA if CUDA.functional() @testset "CUDABackend" begin - nnlib_testsuite(CUDABackend) + nnlib_testsuite(CUDABackend; skip_tests=Set(("Scatter", "Gather"))) end else @info "CUDA.jl is not functional. Skipping test suite for CUDABackend." From 99f6ed4098df95e21fcea0a4797780daf1d32fb9 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 11 Apr 2023 16:00:22 +0300 Subject: [PATCH 07/10] Fixup --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index ca9dd3def..3d7786ccc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,9 +46,9 @@ end @testset "NNlib.jl" verbose=true begin @testset verbose=true "Test Suite" begin - # @testset "CPU" begin - # nnlib_testsuite(CPU) - # end + @testset "CPU" begin + nnlib_testsuite(CPU) + end if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" using CUDA From 19040c038baabca5a8e63b4abf64a8b88a8f358b Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 11 Apr 2023 18:43:59 +0300 Subject: [PATCH 08/10] Add at-inbounds --- src/gather.jl | 2 +- src/scatter.jl | 40 ++++++++++++++++++++++------------------ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/gather.jl b/src/gather.jl index a25b4fc31..1ad69df24 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -125,7 +125,7 @@ end ) i = @index(Global) j, k = divrem(i - 1, max_dims_idx) - dst[i] = src[dim_ids[k + 1], Tuple(idx[j + 1])...] + @inbounds dst[i] = src[dim_ids[k + 1], Tuple(idx[j + 1])...] end ∇gather_src(Δ, src_size, idx) = scatter!(+, fill!(similar(Δ, eltype(Δ), src_size), 0), Δ, idx) diff --git a/src/scatter.jl b/src/scatter.jl index b4e5e54dd..ad92ea1a8 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -108,8 +108,8 @@ end @kernel function _scatter!(op::OP, dst, src, idxs) where OP i = @index(Global) - idx = Tuple(idxs[i]) - Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) + @inbounds idx = Tuple(idxs[i]) + @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) # FIXME `@atomic` macro silently fails to perform atomic op below # @atomic dst[idx...] = op(dst[idx...], src[i]) end @@ -119,8 +119,8 @@ end ) where OP i = @index(Global) j, k = divrem(i - 1, max_dims_idx) - idx = (Tuple(dim_ids[k + 1])..., Tuple(idxs[j + 1])...) - Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) + @inbounds idx = (Tuple(dim_ids[k + 1])..., Tuple(idxs[j + 1])...) + @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) # FIXME # dim_i = Tuple(dim_ids[k + 1]) # idx = idxs[j + 1] @@ -246,13 +246,15 @@ end @kernel function _∇scatter_src(op, Δsrc, src::AbstractArray{T}, idx, rev_idx) where T i = @index(Global) cart_j = CartesianIndices(idx)[i] - inds = rev_idx[Tuple(idx[cart_j])...] - x = one(T) - for k in inds - x *= src[k] + @inbounds begin + inds = rev_idx[Tuple(idx[cart_j])...] 
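+ # leave-one-out product: multiply every value scattered to the same
+ # destination, then divide this element's own contribution back out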
+ x = one(T) + for k in inds + x *= src[k] + end + x /= src[cart_j] + Δsrc[cart_j] = op(Δsrc[cart_j], x) end - x /= src[cart_j] - Δsrc[cart_j] = op(Δsrc[cart_j], x) end @kernel function _∇scatter_src( @@ -261,15 +263,17 @@ end ) where T i = @index(Global) j, k = fldmod1(i, max_dims_idx) - cart_j = CartesianIndices(idx)[j] - cart_k = dim_ids[k] - inds = rev_idx[Tuple(cart_j)...] - x = one(T) - for s in inds - x *= src[Tuple(cart_k)..., Tuple(s)...] + @inbounds begin + cart_j = CartesianIndices(idx)[j] + cart_k = dim_ids[k] + inds = rev_idx[Tuple(cart_j)...] + x = one(T) + for s in inds + x *= src[Tuple(cart_k)..., Tuple(s)...] + end + x /= src[i] + Δsrc[i] = op(Δsrc[i], x) end - x /= src[i] - Δsrc[i] = op(Δsrc[i], x) end function ∇scatter_src( From 232111e603e29cd3e4a7d4a9b1896558677eda43 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 11 Apr 2023 19:15:06 +0300 Subject: [PATCH 09/10] Add compat --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 8bc81e773..1782ec8c6 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,9 @@ NNlibAMDGPUExt = "AMDGPU" [compat] AMDGPU = "0.4.8" Adapt = "2, 3.2" +Atomix = "0.1" ChainRulesCore = "1.13" +GPUArraysCore = "0.1" KernelAbstractions = "0.9" Requires = "0.5, 1.0" julia = "1.6" From 1e1ced20e6cb5ace61c106a98b1ccec28d423371 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 11 Apr 2023 20:20:02 +0300 Subject: [PATCH 10/10] Use KA unsafe free --- Project.toml | 2 +- src/scatter.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 1782ec8c6..000a29a7a 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ Adapt = "2, 3.2" Atomix = "0.1" ChainRulesCore = "1.13" GPUArraysCore = "0.1" -KernelAbstractions = "0.9" +KernelAbstractions = "0.9.2" Requires = "0.5, 1.0" julia = "1.6" diff --git a/src/scatter.jl b/src/scatter.jl index ad92ea1a8..6057e4528 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -239,7 +239,7 @@ function ∇scatter_src( end _∇scatter_src(KernelAbstractions.get_backend(src))( op, Δsrc, src, idx, rev_idx, args...; ndrange) - # TODO KernelAbstractions.unsafe_free!(rev_idx) + KernelAbstractions.unsafe_free!(rev_idx) return Δsrc end
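For orientation, here is the contract the kernels above implement, as a minimal
REPL sketch. It reuses the `src`/`idx` fixtures from test/scatter.jl; the
expected values are worked out by hand from those fixtures, and nothing beyond
the public `NNlib.gather`/`NNlib.scatter` API shown in these patches is assumed:

    using NNlib

    src = Float64[3, 4, 5, 6, 7]
    idx = [1 2 3 4;
           4 2 1 3;
           3 5 5 3]

    # gather reads: out[k] = src[idx[k]] for every k in CartesianIndices(idx)
    out = NNlib.gather(src, idx)             # 3×4 Matrix; out[1, 1] == 3.0, out[3, 2] == 7.0

    # scatter! is the reverse map: all entries of `src` whose index entry
    # points at the same destination are reduced into it with `op`
    dst = zeros(5)
    NNlib.scatter!(+, dst, ones(3, 4), idx)  # idx contains 3 four times, so dst[3] == 4.0

    # the non-mutating form can size the destination explicitly
    z = NNlib.scatter(+, ones(5), [2, 2, 3, 4, 4]; dstsize = (6,))
    size(z)                                  # (6,)

The `res` dictionary at the top of test/scatter.jl encodes exactly these
reductions, per `op`, for both the 1-d (`dims = 0`) and 2-d (`dims = 1`)
destination fixtures.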