diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index 8fbf085fc..7141e7d09 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -83,12 +83,10 @@ export generic_loss_function, compute_enzyme_gradient, compute_zygote_gradient, end @testitem "Enzyme Integration" setup=[EnzymeTestSetup, SharedTestSetup] tags=[ - :autodiff, :enzyme] begin + :autodiff, :enzyme] timeout=3600 begin rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - # TODO: Currently all the tests are run on CPU. We should eventually add tests for - # CUDA and AMDGPU. ongpu && continue @testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST) @@ -106,15 +104,13 @@ end end end -@testitem "Enzyme Integration ComponentArray" setup=[EnzymeTestSetup, SharedTestSetup] tags=[ +@testitem "Enzyme Integration ComponentArray" setup=[EnzymeTestSetup, SharedTestSetup] timeout=3600 tags=[ :autodiff, :enzyme] begin using ComponentArrays rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - # TODO: Currently all the tests are run on CPU. We should eventually add tests for - # CUDA and AMDGPU. ongpu && continue @testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST) diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 3adea7323..1e36b9899 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -165,12 +165,19 @@ end ) x = randn(SVector{N, Float64}) - fun = let d = d, x = x - ps -> sum(d(x, ps, (;))[1]) - end - grad1 = ForwardDiff.gradient(fun, ComponentVector(ps)) - grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)[1] - @test maximum(abs, grad1 .- ComponentVector(grad2)) < 1e-6 + broken = pkgversion(Enzyme) ≥ v"0.13.18" + + @test begin + grad1 = ForwardDiff.gradient(ComponentArray(ps)) do ps + sumabs2first(d, x, ps, (;)) + end + + grad2 = Enzyme.gradient( + Enzyme.Reverse, sumabs2first, Const(d), Const(x), ps, Const((;)) + )[3] + + maximum(abs, grad1 .- ComponentArray(grad2)) < 1e-6 + end broken=broken end end