diff --git a/Project.toml b/Project.toml index e8c55345e..000a29a7a 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,9 @@ version = "0.8.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -21,8 +23,10 @@ NNlibAMDGPUExt = "AMDGPU" [compat] AMDGPU = "0.4.8" Adapt = "2, 3.2" +Atomix = "0.1" ChainRulesCore = "1.13" -KernelAbstractions = "0.9" +GPUArraysCore = "0.1" +KernelAbstractions = "0.9.2" Requires = "0.5, 1.0" julia = "1.6" diff --git a/ext/NNlibCUDA/src/NNlibCUDA.jl b/ext/NNlibCUDA/src/NNlibCUDA.jl index b0d5ea9e7..e1c3c225c 100644 --- a/ext/NNlibCUDA/src/NNlibCUDA.jl +++ b/ext/NNlibCUDA/src/NNlibCUDA.jl @@ -13,7 +13,6 @@ include("batchedmul.jl") include("ctc.jl") include("fold.jl") include("scatter.jl") -include("gather.jl") include("utils.jl") include("cudnn/cudnn.jl") include("cudnn/conv.jl") diff --git a/ext/NNlibCUDA/src/gather.jl b/ext/NNlibCUDA/src/gather.jl deleted file mode 100644 index 5690dfc59..000000000 --- a/ext/NNlibCUDA/src/gather.jl +++ /dev/null @@ -1,65 +0,0 @@ -function gather_check_dims(X::AbstractArray{Tx,Nx}, - Y::AbstractArray{Ty,Ny}, - idx::AbstractArray{Tidx,Nidx}) where - {Tx,Ty,Tidx<:IntOrIntTuple,Nx,Ny,Nidx} - M = NNlib.typelength(Tidx) - dims = gather_check_dims(Nx, Ny, M, Nidx) - size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes.")) - size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) - return dims -end - -function gather_check_dims(X::AbstractArray{Tx,Nx}, - Y::AbstractArray{Ty,Ny}, - idx::AbstractArray{CartesianIndex{M},Nidx}) where - {Tx,Ty,Nx,Ny,M,Nidx} - dims = gather_check_dims(Nx, Ny, M, Nidx) - size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes.")) - size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) - return dims -end - -function gather_check_dims(Nx, Ny, M, Nidx) - @assert Nx - M == Ny - Nidx "Incompatible input shapes of (dst, src, idx) = ($Nx, $Ny, $Nidx)." - dims = Nx - M - dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims.")) - return dims -end - -function gather_kernel!(dst, src, idx, max_idx, max_dims_idx, dims_size) - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) - dims_i = CartesianIndices(dims_size)[k+1] - dst[index] = src[dims_i, idx[j+1]...] - end - return nothing -end - -function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size) - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - - @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) - dims_i = CartesianIndices(dims_size)[k+1] - li = Base._to_linear_index(src, Tuple(dims_i)..., Tuple(idx[j+1])...) - dst[index] = src[li] - end - return nothing -end - -function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) - dims = gather_check_dims(src, dst, idx) - dims_size = size(src)[1:dims] - max_dims_idx = prod(dims_size) - max_idx = max_dims_idx * length(idx) - args = dst, src, idx, max_idx, max_dims_idx, dims_size - - kernel = @cuda launch=false gather_kernel!(args...) - config = launch_configuration(kernel.fun; max_threads=256) - threads = min(max_idx, config.threads) - blocks = cld(max_idx, threads) - kernel(args...; threads=threads, blocks=blocks) - return dst -end diff --git a/ext/NNlibCUDA/src/scatter.jl b/ext/NNlibCUDA/src/scatter.jl index 6a9c5ba87..874207c77 100644 --- a/ext/NNlibCUDA/src/scatter.jl +++ b/ext/NNlibCUDA/src/scatter.jl @@ -30,7 +30,7 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size return nothing end -function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, +function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -73,7 +73,7 @@ end ## Gradients -function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, +function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -93,7 +93,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, return nothing end -function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, +function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -113,7 +113,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca return nothing end -function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, +function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -160,13 +160,13 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::AnyCuArray{Tsrc,Nsrc}, + src::AnyCuArray{Tsrc,Nsrc}, idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} dims = Nsrc - Nidx Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) rev_idx = NNlib.reverse_indices(idx) rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx)) - + if dims == 0 max_idx = length(idx) args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc diff --git a/ext/NNlibCUDA/test/gather.jl b/ext/NNlibCUDA/test/gather.jl index f200f77b7..36d42dbcc 100644 --- a/ext/NNlibCUDA/test/gather.jl +++ b/ext/NNlibCUDA/test/gather.jl @@ -1,7 +1,7 @@ -@testset "gather" begin +@testset "gather" begin T = Float32 CT = CuArray{Float32} - + ## 1d src, 2d index of ints -> 2d output src = CT([3, 4, 5, 6, 7]) index = cu([1 2 3 4; @@ -10,14 +10,14 @@ output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5]) - + y = NNlib.gather(src, index) @test y isa CuArray{Float32,2} @test size(y) == size(index) gputest(src -> NNlib.gather(src, index), src, checkgrad=true) @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index) - + ## 1d src, 2d index of tuples -> 2d output src = CT([3, 4, 5, 6, 7]) index = cu([(1,) (2,) (3,) (4,); @@ -26,14 +26,14 @@ output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5]) - + y = NNlib.gather(src, index) @test y isa CuArray{Float32,2} @test size(y) == size(index) gputest(src -> NNlib.gather(src, index), src, checkgrad=true) @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index) - + ## 1d src, 2d index of CartesianIndex -> 2d output src = CT([3, 4, 5, 6, 7]) index = cu(CartesianIndex.([(1,) (2,) (3,) (4,); @@ -42,7 +42,7 @@ output = CT([3 4 5 6; 6 4 3 5; 5 7 7 5]) - + y = NNlib.gather(src, index) @test y isa CuArray{Float32,2} @test size(y) == size(index) @@ -66,7 +66,7 @@ ## 2d src, 2d index of ints -> 3d output - src = CT([3 5 7 + src = CT([3 5 7 4 6 8]) index = cu([1 2 3; 2 2 1; @@ -79,14 +79,14 @@ output[:,:,2] = [5 5 3 6 6 4] - + output[:,:,3] = [7 3 7 8 4 8] - + y = NNlib.gather(src, index) M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa CuArray{Float32,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) gputest(src -> NNlib.gather(src, index), src, checkgrad=true) end diff --git a/src/NNlib.jl b/src/NNlib.jl index 184fbfc74..183d83cbc 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -1,10 +1,12 @@ module NNlib +import Atomix import ChainRulesCore: rrule using Base.Broadcast: broadcasted using Base.Threads using ChainRulesCore +using GPUArraysCore using KernelAbstractions using KernelAbstractions: @atomic using LinearAlgebra diff --git a/src/gather.jl b/src/gather.jl index 014d085be..1ad69df24 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -1,41 +1,10 @@ -""" - NNlib.gather!(dst, src, idx) - -Reverse operation of [`scatter!`](@ref). Gathers data from source `src` -and writes it in destination `dst` according to the index array `idx`. -For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to - - dst[:, ... , k] .= src[:, ... , idx[k]...] - -Notice that if `idx` is a vector containing integers, -and both `dst` and `src` are matrices, previous expression simplifies to - - dst[:, k] .= src[:, idx[k]] - -and `k` will run over `1:length(idx)`. - -The elements of `idx` can be integers or integer tuples and may be repeated. -A single `src` column can end up being copied into zero, one, -or multiple `dst` columns. - -See [`gather`](@ref) for an allocating version. -""" -function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) - dims = scatter_dims(src, dst, idx) - colons = ntuple(i -> Colon(), dims) - for k in CartesianIndices(idx) - _view(dst, colons, k) .= _view(src, colons, idx[k]) - end - return dst -end - """ NNlib.gather(src, idx) -> dst -Reverse operation of [`scatter`](@ref). Gathers data from source `src` +Reverse operation of [`scatter`](@ref). Gathers data from source `src` and writes it in a destination `dst` according to the index array `idx`. -For each `k` in `CartesianIndices(idx)`, assign values to `dst` +For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to dst[:, ... , k] .= src[:, ... , idx[k]...] @@ -45,10 +14,10 @@ and `src` is a matrix, previous expression simplifies to dst[:, k] .= src[:, idx[k]] -and `k` will run over `1:length(idx)`. +and `k` will run over `1:length(idx)`. -The elements of `idx` can be integers or integer tuples and may be repeated. -A single `src` column can end up being copied into zero, one, +The elements of `idx` can be integers or integer tuples and may be repeated. +A single `src` column can end up being copied into zero, one, or multiple `dst` columns. See [`gather!`](@ref) for an in-place version. @@ -68,29 +37,19 @@ julia> NNlib.gather([1 2 3; 4 5 6], [1,3,1,3,1]) 4 6 4 6 4 ``` """ -function gather(src::AbstractArray{Tsrc, Nsrc}, - idx::AbstractArray{Tidx, Nidx}) where - {Tsrc, Nsrc, Nidx, Tidx} - - M = typelength(Tidx) +function gather( + src::AbstractArray{Tsrc, Nsrc}, idx::AbstractArray{Tidx, Nidx}, +) where {Tsrc, Nsrc, Nidx, Tidx} + M = typelength(Tidx) dstsize = (size(src)[1:Nsrc-M]..., size(idx)...) dst = similar(src, Tsrc, dstsize) return gather!(dst, src, idx) end -∇gather_src(Δ, src_size, idx) = scatter!(+, fill!(similar(Δ, eltype(Δ), src_size), 0), Δ, idx) - -function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) - y = gather!(dst, src, idx) - src_size = size(src) - gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent()) - return y, gather!_pullback -end - """ gather(src, IJK...) -Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and +Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and call `gather` on it: `gather(src, CartesianIndex.(IJK...))`. # Examples @@ -108,14 +67,72 @@ julia> NNlib.gather(src, [1, 2], [2, 4]) 11 ``` """ -function gather(src::AbstractArray{Tsrc, Nsrc}, - I::AbstractVector{<:Integer}, - J::AbstractVector{<:Integer}, - Ks::AbstractVector{<:Integer}...) where {Nsrc, Tsrc} - +function gather( + src::AbstractArray{Tsrc, Nsrc}, + I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, + Ks::AbstractVector{<:Integer}..., +) where {Nsrc, Tsrc} return gather(src, to_cartesian_index(I, J, Ks...)) end to_cartesian_index(IJK...) = CartesianIndex.(IJK...) @non_differentiable to_cartesian_index(::Any...) +""" + NNlib.gather!(dst, src, idx) + +Reverse operation of [`scatter!`](@ref). Gathers data from source `src` +and writes it in destination `dst` according to the index array `idx`. +For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to + + dst[:, ... , k] .= src[:, ... , idx[k]...] + +Notice that if `idx` is a vector containing integers, +and both `dst` and `src` are matrices, previous expression simplifies to + + dst[:, k] .= src[:, idx[k]] + +and `k` will run over `1:length(idx)`. + +The elements of `idx` can be integers or integer tuples and may be repeated. +A single `src` column can end up being copied into zero, one, +or multiple `dst` columns. + +See [`gather`](@ref) for an allocating version. +""" +function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) + dims = scatter_dims(src, dst, idx) + colons = ntuple(i -> Colon(), dims) + for k in CartesianIndices(idx) + _view(dst, colons, k) .= _view(src, colons, idx[k]) + end + return dst +end + +function gather!(dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) + n_dims = scatter_dims(src, dst, idx) + dims = size(src)[1:n_dims] + max_dims_idx = prod(dims) + ndrange = max_dims_idx * length(idx) + _gather!(KernelAbstractions.get_backend(src))( + dst, src, idx, CartesianIndices(dims), max_dims_idx; ndrange) + return dst +end + +@kernel function _gather!( + dst, @Const(src), @Const(idx), + dim_ids::CartesianIndices, max_dims_idx::Int, +) + i = @index(Global) + j, k = divrem(i - 1, max_dims_idx) + @inbounds dst[i] = src[dim_ids[k + 1], Tuple(idx[j + 1])...] +end + +∇gather_src(Δ, src_size, idx) = scatter!(+, fill!(similar(Δ, eltype(Δ), src_size), 0), Δ, idx) + +function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) + y = gather!(dst, src, idx) + src_size = size(src) + gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent()) + return y, gather!_pullback +end diff --git a/src/scatter.jl b/src/scatter.jl index 88bb42c65..6057e4528 100644 --- a/src/scatter.jl +++ b/src/scatter.jl @@ -14,14 +14,14 @@ typelength(::Type{<:NTuple{M}}) where M = M typelength(::Type{CartesianIndex{M}}) where M = M """ -Performs dimensional consistency checks and return the +Performs dimensional consistency checks and return the dimensionality of the scattered objects. """ -function scatter_dims(X::AbstractArray{Tx,Nx}, - Y::AbstractArray{Ty,Ny}, - idx::AbstractArray{Tidx,Nidx}) where {Tx,Ty,Tidx,Nx,Ny,Nidx} - M = typelength(Tidx) - dims = scatter_dims(Nx, Ny, M, Nidx) +function scatter_dims( + X::AbstractArray{Tx,Nx}, Y::AbstractArray{Ty,Ny}, + idx::AbstractArray{Tidx,Nidx}, +) where {Tx,Ty,Tidx,Nx,Ny,Nidx} + dims = scatter_dims(Nx, Ny, typelength(Tidx), Nidx) size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes.")) size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) return dims @@ -41,7 +41,7 @@ _view(X, colons, k::Union{Integer, CartesianIndex}) = view(X, colons..., k) NNlib.scatter!(op, dst, src, idx) Scatter operation, which writes data in `src` into `dst` at locations `idx`. -A binary reduction operator `op` is applied during the scatter. +A binary reduction operator `op` is applied during the scatter. For each index `k` in `idx`, accumulates values in `dst` according to dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...]) @@ -53,7 +53,7 @@ See also [`scatter`](@ref), [`gather`](@ref). - `op`: Operations to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min` and `mean`. - `dst`: The destination for `src` to aggregate to. This argument will be mutated. - `src`: The source data for aggregating. -- `idx`: The mapping for aggregation from source (index) to destination (value). +- `idx`: The mapping for aggregation from source (index) to destination (value). The `idx` array can contain either integers or tuples. # Examples @@ -81,23 +81,61 @@ function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractA dst end -function scatter!(op::typeof(mean), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) - Ns = scatter!(+, zero(dst), one.(src), idx) - dst_ = scatter!(+, zero(dst), src, idx) - dst .+= safe_div.(dst_, Ns) - return dst +for AT in (AbstractArray, AbstractGPUArray) + @eval function scatter!(op::typeof(mean), dst::$AT, src::$AT, idx::$AT) + Ns = scatter!(+, zero(dst), one.(src), idx) + dst_ = scatter!(+, zero(dst), src, idx) + dst .+= safe_div.(dst_, Ns) + return dst + end end +function scatter!(op::OP, dst::AbstractGPUArray, src::AbstractGPUArray, idx::AbstractGPUArray) where OP + n_dims = scatter_dims(dst, src, idx) + args = if n_dims == 0 + ndrange = length(idx) + () + else + dims = size(dst)[1:n_dims] + max_dims_idx = prod(dims) + ndrange = max_dims_idx * length(idx) + (CartesianIndices(dims), max_dims_idx) + end + _scatter!(KernelAbstractions.get_backend(dst))( + op, dst, src, idx, args...; ndrange) + dst +end + +@kernel function _scatter!(op::OP, dst, src, idxs) where OP + i = @index(Global) + @inbounds idx = Tuple(idxs[i]) + @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) + # FIXME `@atomic` macro silently fails to perform atomic op below + # @atomic dst[idx...] = op(dst[idx...], src[i]) +end + +@kernel function _scatter!( + op::OP, dst, src, idxs, dim_ids::CartesianIndices, max_dims_idx::Int, +) where OP + i = @index(Global) + j, k = divrem(i - 1, max_dims_idx) + @inbounds idx = (Tuple(dim_ids[k + 1])..., Tuple(idxs[j + 1])...) + @inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i]) + # FIXME + # dim_i = Tuple(dim_ids[k + 1]) + # idx = idxs[j + 1] + # @atomic dst[dim_i..., idx...] = op(dst[dim_i..., idx...], src[i]) +end """ NNlib.scatter(op, src, idx; [init, dstsize]) -Scatter operation allocating a destination array `dst` and +Scatter operation allocating a destination array `dst` and calling `scatter!(op, dst, src, idx)` on it. * If keyword `init` is provided, it is used to initialize the content of `dst`. Otherwise, the init values is inferred from the reduction operator `op` - for some common operators (e.g. `init = 0` for `op = +`). + for some common operators (e.g. `init = 0` for `op = +`). * If `dstsize` is provided, it will be used to define the size of destination array, otherwise it will be inferred by `src` and `idx`. @@ -127,15 +165,14 @@ julia> NNlib.scatter(*, [10,200,3000], [1,4,2]; init = 10, dstsize = 6) 10 ``` """ -function scatter(op::OP, - src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{Tidx,Nidx}; - init = nothing, dstsize = nothing) where {Tsrc,Tidx,Nsrc,Nidx,OP} - +function scatter( + op::OP, src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}; + init = nothing, dstsize = nothing, +) where {Tsrc,Tidx,Nsrc,Nidx,OP} dims = Nsrc - Nidx - dstsz = isnothing(dstsize) ? (size(src)[1:dims]..., maximum_dims(idx)...) : dstsize + dstsz = isnothing(dstsize) ? (size(src)[1:dims]..., maximum_dims(idx)...) : dstsize dst = similar(src, Tsrc, dstsz) - xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init + xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init fill!(dst, xinit) scatter!(op, dst, src, idx) end @@ -147,13 +184,13 @@ scatter_empty(op::typeof(min), T) = typemax(T) scatter_empty(op::typeof(max), T) = typemin(T) scatter_empty(op::typeof(mean), T) = zero(T) - ## Gradients -∇scatter!_src(op, Δ, dst, src, idx) = ∇scatter_src(op, Δ, dst, src, idx) +∇scatter!_src(op, Δ, dst, src, idx) = ∇scatter_src(op, Δ, dst, src, idx) +∇scatter!_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src, idx) = + gather(dst, idx) .* ∇scatter_src(op, Δ, dst, src, idx) ∇scatter!_dst(op, Δ, dst, y) = Δ - -∇scatter!_dst(op::Union{typeof(max),typeof(min)}, Δ, dst_old, dst) = +∇scatter!_dst(op::Union{typeof(max),typeof(min)}, Δ, dst_old, dst) = (dst_old .== op.(dst_old, dst)) .* Δ modify_src(::typeof(+), X) = X @@ -161,14 +198,15 @@ modify_src(::typeof(-), X) = -X modify_src(::typeof(*), X, Y) = X modify_src(::typeof(/), X, Y) = .-X ./ Y.^2 -∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = modify_src(op, gather(Δ, idx)) +∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = + modify_src(op, gather(Δ, idx)) +∇scatter_src(::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = + (src .== gather(dst, idx)) .* gather(Δ, idx) -∇scatter!_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src, idx) = - gather(dst, idx) .* ∇scatter_src(op, Δ, dst, src, idx) - -function ∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} +function ∇scatter_src( + op::Union{typeof(*),typeof(/)}, Δ, dst, + src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}, +) where {Tsrc,Tidx,Nsrc,Nidx} dims = Nsrc - Nidx Δsrc = modify_src(op, gather(Δ, idx), src) rev_idx = reverse_indices(idx) @@ -182,12 +220,66 @@ function ∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, Δsrc end -∇scatter_src(::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = (src .== gather(dst, idx)) .* gather(Δ, idx) +function ∇scatter_src( + op::Union{typeof(*), typeof(/)}, Δ, dst, + src::AbstractGPUArray{Tsrc, Nsrc}, idx::AbstractGPUArray{Tidx, Nidx}, +) where {Tsrc, Nsrc, Tidx, Nidx} + n_dims = Nsrc - Nidx + Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) + rev_idx = NNlib.reverse_indices(idx) + + args = if n_dims == 0 + ndrange = length(idx) + () + else + dims = size(dst)[1:n_dims] + max_dims_idx = prod(dims) + ndrange = max_dims_idx * length(idx) + (CartesianIndices(dims), max_dims_idx) + end + _∇scatter_src(KernelAbstractions.get_backend(src))( + op, Δsrc, src, idx, rev_idx, args...; ndrange) + KernelAbstractions.unsafe_free!(rev_idx) + return Δsrc +end + +@kernel function _∇scatter_src(op, Δsrc, src::AbstractArray{T}, idx, rev_idx) where T + i = @index(Global) + cart_j = CartesianIndices(idx)[i] + @inbounds begin + inds = rev_idx[Tuple(idx[cart_j])...] + x = one(T) + for k in inds + x *= src[k] + end + x /= src[cart_j] + Δsrc[cart_j] = op(Δsrc[cart_j], x) + end +end + +@kernel function _∇scatter_src( + op, Δsrc, src::AbstractArray{T}, idx, rev_idx, + dim_ids::CartesianIndices, max_dims_idx::Int, +) where T + i = @index(Global) + j, k = fldmod1(i, max_dims_idx) + @inbounds begin + cart_j = CartesianIndices(idx)[j] + cart_k = dim_ids[k] + inds = rev_idx[Tuple(cart_j)...] + x = one(T) + for s in inds + x *= src[Tuple(cart_k)..., Tuple(s)...] + end + x /= src[i] + Δsrc[i] = op(Δsrc[i], x) + end +end -function ∇scatter_src(::typeof(mean), Δ, dst, - src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} - +function ∇scatter_src( + ::typeof(mean), Δ, dst, + src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}, +) where {Tsrc,Tidx,Nsrc,Nidx} M = typelength(Tidx) num = gather(Δ, idx) counts = fill!(similar(Δ, Int, size(Δ)[end-M+1:end]), 0) @@ -200,8 +292,7 @@ function ∇scatter_src(::typeof(mean), Δ, dst, return safe_div.(num, den) end -∇scatter_src(op, Δ, dst, src, idx) = - ∇scatter_src(op, unthunk(Δ), dst, src, idx) +∇scatter_src(op, Δ, dst, src, idx) = ∇scatter_src(op, unthunk(Δ), dst, src, idx) function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) dst_old = copy(dst) diff --git a/test/gather.jl b/test/gather.jl index 21ce1630d..e3221145b 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -1,162 +1,181 @@ using NNlib: gather, gather! -@testset "gather scalar index" begin +function gather_testsuite(Backend) + device(x) = adapt(Backend(), x) + gradtest_fn = Backend == CPU ? gradtest : gputest T = Float32 - - ## 1d src, 2d index of ints -> 2d output - src = T[3, 4, 5, 6, 7] - index = [1 2 3 4; - 4 2 1 3; - 3 5 5 3] - output = T[3 4 5 6; - 6 4 3 5; - 5 7 7 5] - - y = gather(src, index) - @test y isa Array{T,2} - @test size(y) == size(index) - @test y == output - @test gather!(T.(zero(index)), src, index) == output - @test_throws ArgumentError gather!(zeros(T, 3, 5), src, index) - - index2 = [1 2 3 4; - 4 2 1 3; - 3 6 5 3] - @test_throws BoundsError gather!(T.(zero(index)), src, index2) - - ## 1d src, 3d index of ints -> 3d output - src = T[3, 4, 5, 6, 7] - index = [1 2 3 4; - 4 2 1 3; - 3 5 5 3][:,:,1:1] - output = T[3 4 5 6; - 6 4 3 5; - 5 7 7 5][:,:,1:1] - - y = gather(src, index) - @test y isa Array{T,3} - @test size(y) == size(index) - @test y == output - - - ## 2d src, 2d index of ints -> 3d output - src = T[3 5 7 - 4 6 8] - index = [1 2 3; - 2 2 1; - 3 1 3] - - output = zeros(T, 2, 3, 3) - - output[:,:,1] = [3 5 7 - 4 6 8] - - output[:,:,2] = [5 5 3 - 6 6 4] - - output[:,:,3] = [7 3 7 - 8 4 8] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) - @test y == output -end - -@testset "gather tuple index" begin - T = Float32 - - ## 2d src, 1d index of 2-tuples -> 1d output - src = T[3 5 7 - 4 6 8] - index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] - - output = T[3, 5, 7, 4, 6, 8] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,1} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) - @test y == output - - ## 3d src, 2d index of 2-tuples -> 3d output - n1, nsrc, nidx = 2, 3, 6 - src = rand(Float32, n1, nsrc, nsrc) - index = [(rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) -end - -@testset "gather cartesian index" begin - T = Float32 - - ## 2d src, 1d index of 2-tuples -> 1d output - src = T[3 5 7 - 4 6 8] - - index = CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) - - output = T[3, 5, 7, 4, 6, 8] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,1} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) - @test y == output - - ## 3d src, 2d index of 2-tuples -> 3d output - n1, nsrc, nidx = 2, 3, 6 - src = rand(Float32, n1, nsrc, nsrc) - index = [CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i=1:nidx, j=1:nidx] - - y = gather(src, index) - M = NNlib.typelength(eltype(index)) - Nsrc = ndims(src) - @test y isa Array{T,3} - @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) -end - -@testset "gather gradient for scalar index" begin - T = Float64 - src = T[3, 4, 5, 6, 7] - index = [1 2 3 4; + @testset "gather scalar index" begin + ## 1d src, 2d index of ints -> 2d output + src = device(T[3, 4, 5, 6, 7]) + index = device([ + 1 2 3 4; 4 2 1 3; - 3 5 5 3] - dst = T[3 4 5 6; + 3 5 5 3]) + output = T[ + 3 4 5 6; 6 4 3 5; 5 7 7 5] - gradtest(xs -> gather!(dst, xs, index), src) - gradtest(xs -> gather(xs, index), src) -end + y = cpu(gather(src, index)) + @test y isa Array{T,2} + @test size(y) == size(index) + @test y == output + + dst = device(T.(zero(index))) + @test cpu(gather!(dst, src, index)) == output + dst = device(zeros(T, 3, 5)) + @test_throws ArgumentError gather!(dst, src, index) + + if Backend == CPU + index2 = [1 2 3 4; + 4 2 1 3; + 3 6 5 3] + @test_throws BoundsError gather!(T.(zero(index)), src, index2) + end + + ## 1d src, 3d index of ints -> 3d output + src = device(T[3, 4, 5, 6, 7]) + index = device([ + 1 2 3 4; + 4 2 1 3; + 3 5 5 3][:,:,1:1]) + output = T[ + 3 4 5 6; + 6 4 3 5; + 5 7 7 5][:,:,1:1] + + y = cpu(gather(src, index)) + @test y isa Array{T,3} + @test size(y) == size(index) + @test y == output + + ## 2d src, 2d index of ints -> 3d output + src = device(T[ + 3 5 7 + 4 6 8]) + index = device([ + 1 2 3; + 2 2 1; + 3 1 3]) -@testset "gather gradient for tuple index" begin - T = Float64 - src = T[3 5 7 + output = zeros(T, 2, 3, 3) + output[:,:,1] = [ + 3 5 7 4 6 8] - index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] - dst = T[3, 5, 7, 4, 6, 8] - - gradtest(xs -> gather!(dst, xs, index), src) - gradtest(xs -> gather(xs, index), src) + output[:,:,2] = [ + 5 5 3 + 6 6 4] + output[:,:,3] = [ + 7 3 7 + 8 4 8] + + y = cpu(gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y == output + end + + @testset "gather tuple index" begin + ## 2d src, 1d index of 2-tuples -> 1d output + src = device(T[ + 3 5 7 + 4 6 8]) + index = device([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) + output = T[3, 5, 7, 4, 6, 8] + + y = cpu(gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y == output + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = device(rand(T, n1, nsrc, nsrc)) + index = device([ + (rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx]) + + y = cpu(gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + end + + @testset "gather cartesian index" begin + ## 2d src, 1d index of 2-tuples -> 1d output + src = device(T[ + 3 5 7 + 4 6 8]) + index = device(CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)])) + output = T[3, 5, 7, 4, 6, 8] + + y = cpu(gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,1} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + @test y == output + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = device(rand(Float32, n1, nsrc, nsrc)) + index = device([ + CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i=1:nidx, j=1:nidx]) + + y = cpu(gather(src, index)) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) + end + + @testset "gather gradient for scalar index" begin + src = device(Float64[3, 4, 5, 6, 7]) + idx = device([ + 1 2 3 4; + 4 2 1 3; + 3 5 5 3]) + dst = device(Float64[ + 3 4 5 6; + 6 4 3 5; + 5 7 7 5]) + Backend == CPU ? + gradtest_fn(xs -> gather!(dst, xs, idx), src) : + gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx) + Backend == CPU ? + gradtest_fn(xs -> gather(xs, idx), src) : + gradtest_fn((s, i) -> gather(s, i), src, idx) + end + + @testset "gather gradient for tuple index" begin + src = device(Float64[ + 3 5 7 + 4 6 8]) + idx = device([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]) + dst = device(Float64[3, 5, 7, 4, 6, 8]) + Backend == CPU ? + gradtest_fn(xs -> gather!(dst, xs, idx), src) : + gradtest_fn((d, s, i) -> gather!(d, s, i), dst, src, idx) + Backend == CPU ? + gradtest_fn(xs -> gather(xs, idx), src) : + gradtest_fn((s, i) -> gather(s, i), src, idx) + end + + @testset "gather(src, IJK...)" begin + x = device(reshape([1:15;], 3, 5)) + i, j = device([1,2]), device([2,4]) + y = gather(x, i, j) + @test cpu(y) == [4, 11] + y = gather(x, device([1, 2])) + @test cpu(y) == [ + 1 4 + 2 5 + 3 6] + end end -@testset "gather(src, IJK...)" begin - x = reshape([1:15;], 3, 5) - - y = gather(x, [1,2], [2,4]) - @test y == [4, 11] - - @test gather(x, [1, 2]) == [1 4 - 2 5 - 3 6] -end diff --git a/test/runtests.jl b/test/runtests.jl index 8b9061b77..3d7786ccc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,12 +28,20 @@ end cpu(x) = adapt(CPU(), x) +include("gather.jl") +include("scatter.jl") include("upsample.jl") function nnlib_testsuite(Backend; skip_tests = Set{String}()) @conditional_testset "Upsample" skip_tests begin upsample_testsuite(Backend) end + @conditional_testset "Gather" skip_tests begin + gather_testsuite(Backend) + end + @conditional_testset "Scatter" skip_tests begin + scatter_testsuite(Backend) + end end @testset "NNlib.jl" verbose=true begin @@ -46,7 +54,7 @@ end using CUDA if CUDA.functional() @testset "CUDABackend" begin - nnlib_testsuite(CUDABackend) + nnlib_testsuite(CUDABackend; skip_tests=Set(("Scatter", "Gather"))) end else @info "CUDA.jl is not functional. Skipping test suite for CUDABackend." @@ -80,7 +88,7 @@ end end end - @testset "Tests" begin + @testset verbose=true "Tests" begin if get(ENV, "NNLIB_TEST_CUDA", "false") == "true" using CUDA if CUDA.functional() @@ -171,14 +179,6 @@ end include("softmax.jl") end - @testset "Gather" begin - include("gather.jl") - end - - @testset "Scatter" begin - include("scatter.jl") - end - @testset "Utilities" begin include("utils.jl") end diff --git a/test/scatter.jl b/test/scatter.jl index 0383e8e5d..26fc06cde 100644 --- a/test/scatter.jl +++ b/test/scatter.jl @@ -44,168 +44,168 @@ res = Dict( 6 4 8 8 6], (min, 0, true) => [1, 1, 1, 1, 1], (min, 1, true) => [1 1 1 1 1; - 1 1 1 1 1], + 1 1 1 1 1], (min, 0, false) => [1, 2, 1, 1, 2], (min, 1, false) => [1 2 1 1 2; 2 4 2 2 4], (*, 0, true) => [3, 4, 5, 6, 7], (*, 1, true) => [3 3 4 4 5; - 5 5 6 6 7], + 5 5 6 6 7], (*, 0, false) => [3, 4, 48, 4, 6], (*, 1, false) => [3 4 48 4 6; 12 16 768 16 24], (/, 0, true) => [0.75, 1., 0.3125, 1.5, 1.75], (/, 1, true) => [0.75 0.75 0.25 1. 1.25; - 1.25 1.25 0.375 1.5 1.75], + 1.25 1.25 0.375 1.5 1.75], (/, 0, false) => [1//3, 1//4, 1//48, 1//4, 1//6], (/, 1, false) => [1//3 1//4 1//48 1//4 1//6; 1//12 1//16 1//768 1//16 1//24], (mean, 0, true) => [4., 5., 6., 7., 8.], (mean, 1, true) => [4. 4. 5. 5. 6.; - 6. 6. 7. 7. 8.], + 6. 6. 7. 7. 8.], (mean, 0, false) => [2, 2, 3, 2.5, 2.5], (mean, 1, false) => [2. 2. 3. 2.5 2.5; 4. 4. 6. 5. 5.], ) -types = [UInt8, UInt32, Int64, - Float16, Float32, Float64, BigFloat, Rational] - -@testset "scatter" begin - for T = types +function test_scatter(device, types, ops; pt, ops_skip_types) + for T in types + PT = promote_type(T, pt) @testset "$T" begin - PT = promote_type(T, Int) - @testset "+" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(+, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)]) - @test scatter!(+, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(+, dims, mutated)]) - @test scatter!(+, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(+, dims, mutated)]) - - mutated = false - @test scatter(+, T.(srcs[(dims, mutated)]), idx) == T.(res[(+, dims, mutated)]) - end - end - - @testset "-" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(-, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)]) - @test scatter!(-, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(-, dims, mutated)]) - @test scatter!(-, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(-, dims, mutated)]) - - mutated = false - if !(T in [UInt8, UInt16, UInt32, UInt64, UInt128]) - @test scatter(-, T.(srcs[(dims, mutated)]), idx) == T.(res[(-, dims, mutated)]) - end - end - end - - @testset "max" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(max, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)]) - @test scatter!(max, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(max, dims, mutated)]) - @test scatter!(max, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(max, dims, mutated)]) - - mutated = false - if !(T in [BigInt]) - @test scatter(max, T.(srcs[(dims, mutated)]), idx) == T.(res[(max, dims, mutated)]) - end - end - end - - @testset "min" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(min, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)]) - @test scatter!(min, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(min, dims, mutated)]) - @test scatter!(min, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(min, dims, mutated)]) - - mutated = false - if !(T in [BigInt]) - @test scatter(min, T.(srcs[(dims, mutated)]), idx) == T.(res[(min, dims, mutated)]) - end - end - end - - @testset "*" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(*, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)]) - @test scatter!(*, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(*, dims, mutated)]) - @test scatter!(*, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(*, dims, mutated)]) - - mutated = false - if !(T in [UInt8, Int8]) - @test scatter(*, T.(srcs[(dims, mutated)]), idx) == T.(res[(*, dims, mutated)]) + for op in ops + skip_types = get(ops_skip_types, op, []) + @testset "$op" begin + for idx = values(idxs), dims = [0, 1] + idx = device(idx) + dst = device(dsts[dims]) + + mutated = true + target_y = res[(op, dims, mutated)] + src = device(srcs[(dims, mutated)]) + if op == / + src = src .* T(2) + end + + @test cpu(scatter!(op, T.(dst), T.(src), idx)) == T.(target_y) + @test cpu(scatter!(op, T.(dst), src, idx)) == PT.(target_y) + if op == / + @test cpu(scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y) + else + @test cpu(scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y) + end + + if T ∉ skip_types + mutated = false + src = device(srcs[(dims, mutated)]) + @test cpu(scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)]) + end end end end end end +end - for T = [Float16, Float32, BigFloat, Rational] - @testset "$T" begin - PT = promote_type(T, Float64) - @testset "/" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == T.(res[(/, dims, mutated)]) - @test scatter!(/, T.(dsts[dims]), srcs[(dims, mutated)].*2, idx) == PT.(res[(/, dims, mutated)]) - @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == PT.(res[(/, dims, mutated)]) - - mutated = false - @test scatter(/, T.(srcs[(dims, mutated)]), idx) == T.(res[(/, dims, mutated)]) - end - end - - @testset "mean" begin - for idx = values(idxs), dims = [0, 1] - mutated = true - @test scatter!(mean, T.(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)]) - @test scatter!(mean, T.(dsts[dims]), srcs[(dims, mutated)], idx) == PT.(res[(mean, dims, mutated)]) - @test scatter!(mean, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(mean, dims, mutated)]) - - mutated = false - @test scatter(mean, T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)]) - end - end - end +function scatter_testsuite(Backend) + device(x) = adapt(Backend(), x) + gradtest_fn = Backend == CPU ? gradtest : gputest + + ops_skip_types = Dict( + (+) => [], + (-) => [UInt8, UInt16, UInt32, UInt64, UInt128], + (*) => [UInt8, Int8], + max => [BigInt], + min => [BigInt]) + types = if Backend == CPU + [UInt8, UInt32, UInt64, Int32, Int64, Float16, Float32, Float64, BigFloat, Rational] + elseif Symbol(Backend) == :CUDABackend + [Int32, Int64, Float32, Float64] + else + # Need LLVM 15+ for atomic fmin/fmax: + # https://reviews.llvm.org/D127041 + # But fmin/fmax can be done by reinterpreting an array to `UInt`. + [Int32, Int64, UInt32, UInt64] end - - @test_throws AssertionError scatter!(+, dsts[0], srcs[(1, true)], idxs[:int]) - idx = [1 2 3 4; 4 2 1 3; 6 7 8 9] - @test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx) - - @testset "dstsize" begin - idx = [2, 2, 3, 4, 4] - src = ones(3, 5) - y = scatter(+, src, idx, dstsize = (3, 6)) - @test size(y) == (3, 6) - gradtest(x -> scatter(+, x, idx, dstsize = (3,6)), src) + ops = Backend == CPU ? + (+, -, max, min, *) : + (+, -, max, min) + test_scatter(device, types, ops; pt=Int, ops_skip_types) + + types = Backend == CPU ? + [Float16, Float32, BigFloat, Rational] : + [Float32, Float64] + ops = if Backend == CPU + (/, mean) + elseif Symbol(Backend) == :CUDABackend + (*, /, mean) + else + # LLVM does not support atomic fmul/fdiv: + # https://llvm.org/docs/LangRef.html#atomicrmw-instruction + (mean,) end -end - -@testset "∇scatter" begin - T = Float64 - fdm(op) = op == min ? :backward : :forward - # fdm(op) = :forward + test_scatter(device, types, ops; pt=Float64, ops_skip_types=Dict()) - @testset "∂dst" begin - for op in (+, -, *, /, mean, max, min) - gradtest(xs -> scatter!(op, copy(xs), srcs[(0, true)], idxs[:int]), T.(dsts[0]), fdm=fdm(op)) - gradtest(xs -> scatter!(op, copy(xs), srcs[(1, true)], idxs[:int]), T.(dsts[1]), fdm=fdm(op)) + if Backend == CPU + @testset "scatter exceptions" begin + idx = [1 2 3 4; 4 2 1 3; 6 7 8 9] + @test_throws AssertionError scatter!(+, copy(dsts[0]), srcs[(1, true)], idxs[:int]) + @test_throws BoundsError scatter!(+, copy(dsts[1]), srcs[(1, true)], idx) end end - @testset "∂src" begin - for op in (+, -, *, /, mean, max, min) - gradtest(xs -> scatter!(op, T.(dsts[0]), xs, idxs[:int]), T.(srcs[(0, true)]), fdm=fdm(op)) - gradtest(xs -> scatter!(op, T.(dsts[1]), xs, idxs[:int]), T.(srcs[(1, true)]), fdm=fdm(op)) + @testset "∇scatter" begin + T = Float64 + fdm(op) = op == min ? :backward : :forward + + @testset "dstsize" begin + idx = device([2, 2, 3, 4, 4]) + src = device(ones(T, 3, 5)) + y = scatter(+, src, idx, dstsize = (3, 6)) + @test eltype(y) == T + @test size(y) == (3, 6) + Backend == CPU ? + gradtest_fn(x -> scatter(+, x, idx; dstsize=(3, 6)), src) : + gradtest_fn((x, i) -> scatter(+, x, i; dstsize=(3, 6)), src, idx) + end + + @testset "∂dst" begin + ops = if Backend == CPU || Symbol(Backend) == :CUDABackend + (+, -, *, /, mean, max, min) + else + (+, -, mean, max, min) + end + for op in ops, i in (0, 1) + PT = ( # If not CPU and CUDA -> use Int64 for min/max. + Backend != CPU && + Symbol(Backend) != :CUDABackend && + (op == max || op == min)) ? Int64 : T + + src = device(srcs[(i, true)]) + idx = device(idxs[:int]) + dst = device(PT.(dsts[i])) + Backend == CPU ? + gradtest_fn(x -> scatter!(op, copy(x), src, idx), dst; fdm=fdm(op)) : + gradtest_fn((x, s, i) -> scatter!(op, x, s, i), dst, src, idx) + end + end - gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(0, false)]), fdm=fdm(op)) - gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(1, false)]), fdm=fdm(op)) + @testset "∂src" begin + ops = if Backend == CPU || Symbol(Backend) == :CUDABackend + (+, -, *, /, mean, max, min) + else + (+, -, mean, max, min) + end + for op in ops, i in (0, 1) + PT = ( # If not CPU and CUDA -> use Int64 for min/max. + Backend != CPU && + Symbol(Backend) != :CUDABackend && + (op == max || op == min)) ? Int64 : T + src = PT.(device(srcs[(i, false)])) + idx = device(idxs[:int]) + Backend == CPU ? + gradtest_fn(xs -> scatter(op, xs, idx), src; fdm=fdm(op)) : + gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx) + end end end end diff --git a/test/upsample.jl b/test/upsample.jl index 28109d4b2..c13d64958 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -1,6 +1,6 @@ function upsample_testsuite(Backend) device(x) = adapt(Backend(), x) - gradtest_fn = KernelAbstractions.isgpu(Backend()) ? gputest : gradtest + gradtest_fn = Backend == CPU ? gradtest : gputest T = Float32 # TODO test against all supported eltypes for each backend. atol = T == Float32 ? 1e-3 : 1e-6