diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 0d8926764..1e36b9899 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -165,15 +165,19 @@ end ) x = randn(SVector{N, Float64}) - grad1 = ForwardDiff.gradient(ComponentArray(ps)) do ps - sumabs2first(d, x, ps, (;)) - end + broken = pkgversion(Enzyme) ≥ v"0.13.18" + + @test begin + grad1 = ForwardDiff.gradient(ComponentArray(ps)) do ps + sumabs2first(d, x, ps, (;)) + end - grad2 = Enzyme.gradient( - Enzyme.Reverse, sumabs2first, Const(d), Const(x), ps, Const((;)) - )[3] + grad2 = Enzyme.gradient( + Enzyme.Reverse, sumabs2first, Const(d), Const(x), ps, Const((;)) + )[3] - @test maximum(abs, grad1 .- ComponentArray(grad2)) < 1e-6 + maximum(abs, grad1 .- ComponentArray(grad2)) < 1e-6 + end broken=broken end end