diff --git a/test/autodiff/batched_autodiff_tests.jl b/test/autodiff/batched_autodiff_tests.jl index 633725142..326aa155a 100644 --- a/test/autodiff/batched_autodiff_tests.jl +++ b/test/autodiff/batched_autodiff_tests.jl @@ -1,16 +1,20 @@ @testitem "Batched Jacobian" setup=[SharedTestSetup] tags=[:autodiff] begin - using ComponentArrays, ForwardDiff, Zygote + using ComponentArrays, ForwardDiff, Zygote, ADTypes rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES models = ( - Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), - Conv((3, 3), 4 => 2, gelu; pad=SamePad()), FlattenLayer(), Dense(18 => 2)), - Chain(Dense(2, 4, gelu), Dense(4, 2))) + Chain( + Conv((3, 3), 2 => 4, gelu; pad=SamePad()), + Conv((3, 3), 4 => 2, gelu; pad=SamePad()), + FlattenLayer(), Dense(18 => 2) + ), + Chain(Dense(2, 4, gelu), Dense(4, 2)) + ) Xs = (aType(randn(rng, Float32, 3, 3, 2, 4)), aType(randn(rng, Float32, 2, 4))) - for (model, X) in zip(models, Xs) + for (i, (model, X)) in enumerate(zip(models, Xs)) ps, st = Lux.setup(rng, model) |> dev smodel = StatefulLuxLayer{true}(model, ps, st) @@ -18,7 +22,20 @@ ForwardDiff.jacobian(smodel, X) end - @testset "$(backend)" for backend in (AutoZygote(), AutoForwardDiff()) + @testset for backend in ( + AutoZygote(), AutoForwardDiff(), + AutoEnzyme(; + mode=Enzyme.Forward, function_annotation=Enzyme.Const + ), + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const + ) + ) + # Forward rules for Enzyme is currently not implemented for several Ops + i == 1 && backend isa AutoEnzyme && + ADTypes.mode(backend) isa ADTypes.ForwardMode && continue + J2 = allow_unstable() do batched_jacobian(smodel, backend, X) end @@ -40,7 +57,14 @@ end @testset "Issue #636 Chunksize Specialization" begin - for N in (2, 4, 8, 11, 12, 50, 51), backend in (AutoZygote(), AutoForwardDiff()) + for N in (2, 4, 8, 11, 12, 50, 51), + backend in ( + AutoZygote(), AutoForwardDiff(), AutoEnzyme(), + AutoEnzyme(; mode=Enzyme.Reverse) + ) + + ongpu && backend isa AutoEnzyme && continue + model = @compact(; potential=Dense(N => N, gelu), backend=backend) do x @return allow_unstable() do batched_jacobian(potential, backend, x) @@ -78,6 +102,13 @@ end @test Jx_zygote ≈ Jx_true + if !ongpu + Jx_enzyme = allow_unstable() do + batched_jacobian(ftest, AutoEnzyme(), x) + end + @test Jx_enzyme ≈ Jx_true + end + fincorrect(x) = x[:, 1] x = reshape(Float32.(1:6), 2, 3) |> dev diff --git a/test/autodiff/nested_autodiff_tests.jl b/test/autodiff/nested_autodiff_tests.jl index 6ebe189c4..e328df283 100644 --- a/test/autodiff/nested_autodiff_tests.jl +++ b/test/autodiff/nested_autodiff_tests.jl @@ -267,10 +267,8 @@ end end @test_gradients(__f, x, - ps; - atol=1.0f-3, - rtol=1.0f-3, - skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + ps; atol=1.0f-3, + rtol=1.0f-3, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) end end end @@ -409,6 +407,5 @@ end end @test_gradients(__f, x, ps; atol=1.0f-3, - rtol=1.0f-3, - skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + rtol=1.0f-3, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) end