diff --git a/ext/NNlibAMDGPUExt/activations.jl b/ext/NNlibAMDGPUExt/activations.jl
index 1563bb45e..498cc8a8a 100644
--- a/ext/NNlibAMDGPUExt/activations.jl
+++ b/ext/NNlibAMDGPUExt/activations.jl
@@ -1,13 +1,13 @@
 for (f, op) in [
-    NNlib.relu => MIOpen.relu,
-    NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6),
-    NNlib.softplus => MIOpen.softrelu,
-    NNlib.σ => MIOpen.sigmoid,
-    Base.tanh => MIOpen.tanh,
-    # TODO define for leakyrelu, elu, etc.?
-]
+        NNlib.relu => MIOpen.relu,
+        NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6),
+        NNlib.softplus => MIOpen.softrelu,
+        NNlib.σ => MIOpen.sigmoid,
+        Base.tanh => MIOpen.tanh,
+        # TODO define for leakyrelu, elu, etc.?
+    ], N in 1:5
     @eval function Base.materialize(
-        bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat}}}
+        bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat,$N}}}
     )
         return $op(bc.args[1])
     end
diff --git a/test/ext_amdgpu/activations.jl b/test/ext_amdgpu/activations.jl
index 2abb0c272..afc59a45c 100644
--- a/test/ext_amdgpu/activations.jl
+++ b/test/ext_amdgpu/activations.jl
@@ -1,10 +1,11 @@
 @testset "Compare CPU & GPU" begin
-    for (T, atol) in ((Float16, 1f-2), (Float32, 1f-5))
-        x = randn(T, 16)
-        gputest(x -> NNlib.relu.(x), x; atol)
-        gputest(x -> NNlib.relu6.(x), x; atol)
-        gputest(x -> NNlib.softplus.(x), x; atol)
-        gputest(x -> tanh.(x), x; atol)
-        gputest(x -> identity.(x), x; atol)
+    for (T, atol) in ((Float16, 1.0f-2), (Float32, 1.0f-5))
+        @testset "ndims: $(ndims(x))" for x in (randn(T, 16), randn(T, ntuple(_ -> 2, 5)...), randn(T, ntuple(_ -> 2, 6)...))
+            gputest(x -> NNlib.relu.(x), x; atol)
+            gputest(x -> NNlib.relu6.(x), x; atol)
+            gputest(x -> NNlib.softplus.(x), x; atol)
+            gputest(x -> tanh.(x), x; atol)
+            gputest(x -> identity.(x), x; atol)
+        end
     end
 end
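
Side note on what the extension change expands to (a minimal sketch, not part of the diff): the added `, N in 1:5` iterator makes the `@eval` loop emit one `Base.materialize` method per activation and per array dimensionality, so broadcasting an activation over a 1- to 5-dimensional `ROCArray` with a MIOpen-supported eltype dispatches to the corresponding MIOpen kernel, while anything else falls back to the generic GPU broadcast. For `NNlib.relu` at `N = 3` the generated method looks roughly like this; `ROCArray`, `MIOpen`, and `MIOPENFloat` are the names already in scope inside the `NNlibAMDGPUExt` extension:

```julia
# Approximate expansion of the @eval loop for (f, op) = (NNlib.relu, MIOpen.relu), N = 3.
# Only broadcasts whose sole argument is a 3-D ROCArray of a MIOpen-supported float type
# match this method; everything else keeps Julia's default broadcast machinery.
function Base.materialize(
    bc::Broadcast.Broadcasted{<:Any,<:Any,typeof(NNlib.relu),<:Tuple{ROCArray{<:MIOPENFloat,3}}}
)
    return MIOpen.relu(bc.args[1])   # hand the whole array to the MIOpen kernel
end
```

So with AMDGPU and NNlib loaded, something like `NNlib.relu.(ROCArray(rand(Float32, 2, 2, 2)))` should hit one of these methods, whereas a 6-dimensional input stays on the plain broadcast path, which is what the extra `ntuple(_ -> 2, 6)` test case exercises.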