diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 3adea7323d..0d8926764f 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -165,12 +165,15 @@ end ) x = randn(SVector{N, Float64}) - fun = let d = d, x = x - ps -> sum(d(x, ps, (;))[1]) + grad1 = ForwardDiff.gradient(ComponentArray(ps)) do ps + sumabs2first(d, x, ps, (;)) end - grad1 = ForwardDiff.gradient(fun, ComponentVector(ps)) - grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)[1] - @test maximum(abs, grad1 .- ComponentVector(grad2)) < 1e-6 + + grad2 = Enzyme.gradient( + Enzyme.Reverse, sumabs2first, Const(d), Const(x), ps, Const((;)) + )[3] + + @test maximum(abs, grad1 .- ComponentArray(grad2)) < 1e-6 end end