diff --git a/docs/src/manual/differential_equation_sensitivities.md b/docs/src/manual/differential_equation_sensitivities.md
index 291b3cc49..ca1c37e8d 100644
--- a/docs/src/manual/differential_equation_sensitivities.md
+++ b/docs/src/manual/differential_equation_sensitivities.md
@@ -11,6 +11,7 @@ requiring the user to do any of the setup.
 
 Current AD libraries whose calls are captured by the sensitivity system are:
 
+  - [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)
   - [Zygote.jl](https://fluxml.ai/Zygote.jl/stable/)
   - [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl)
   - [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl)
diff --git a/test/alternative_ad_frontend.jl b/test/alternative_ad_frontend.jl
index b81046961..f5cb22606 100644
--- a/test/alternative_ad_frontend.jl
+++ b/test/alternative_ad_frontend.jl
@@ -1,10 +1,10 @@
-using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker
+using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker, Enzyme
 using Test
 
 prob = ODEProblem((u, p, t) -> u .* p, [2.0], (0.0, 1.0), [3.0])
 
-struct senseloss
-    sense::Any
+struct senseloss{T}
+    sense::T
 end
 function (f::senseloss)(u0p)
     sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12,
@@ -34,8 +34,15 @@ dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
 
 @test ForwardDiff.gradient(senseloss(InterpolatingAdjoint()), u0p) ≈ dup
 
-struct senseloss2
-    sense::Any
+@test Enzyme.gradient(Reverse, senseloss(InterpolatingAdjoint()), u0p) ≈ dup
+@test_broken Enzyme.gradient(Reverse, senseloss(ReverseDiffAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss(TrackerAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss(ForwardDiffSensitivity()), u0p) ≈ dup
+@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError Enzyme.gradient(Reverse, senseloss(ForwardSensitivity()),
+    u0p)≈dup
+
+struct senseloss2{T}
+    sense::T
 end
 
 prob2 = ODEProblem((du, u, p, t) -> du .= u .* p, [2.0], (0.0, 1.0), [3.0])
@@ -62,8 +69,14 @@ dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1]
 
 @test ForwardDiff.gradient(senseloss2(InterpolatingAdjoint()), u0p) ≈ dup
 
-struct senseloss3
-    sense::Any
+@test Enzyme.gradient(Reverse, senseloss2(InterpolatingAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss2(ReverseDiffAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss2(TrackerAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss2(ForwardDiffSensitivity()), u0p) ≈ dup
+@test_broken Enzyme.gradient(Reverse, senseloss2(ForwardSensitivity()), u0p) ≈ dup
+
+struct senseloss3{T}
+    sense::T
 end
 function (f::senseloss3)(u0p)
     sum(solve(prob2, Tsit5(), p = u0p, abstol = 1e-12,
@@ -88,8 +101,14 @@ dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1]
 
 @test ForwardDiff.gradient(senseloss3(InterpolatingAdjoint()), u0p) ≈ dup
 
-struct senseloss4
-    sense::Any
+@test Enzyme.gradient(Reverse, senseloss3(InterpolatingAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss3(ReverseDiffAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss3(TrackerAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss3(ForwardDiffSensitivity()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss3(ForwardSensitivity()), u0p) ≈ dup
+
+struct senseloss4{T}
+    sense::T
 end
 function (f::senseloss4)(u0p)
     sum(solve(prob, Tsit5(), p = u0p, abstol = 1e-12,
@@ -116,6 +135,13 @@ dup = Zygote.gradient(senseloss4(InterpolatingAdjoint()), u0p)[1]
 
 @test ForwardDiff.gradient(senseloss4(InterpolatingAdjoint()), u0p) ≈ dup
 
+@test Enzyme.gradient(Reverse, senseloss4(InterpolatingAdjoint()), u0p) ≈ dup
+@test_broken Enzyme.gradient(Reverse, senseloss4(ReverseDiffAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss4(TrackerAdjoint()), u0p) ≈ dup
+@test Enzyme.gradient(Reverse, senseloss4(ForwardDiffSensitivity()), u0p) ≈ dup
+@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError Enzyme.gradient(Reverse, senseloss4(ForwardSensitivity()),
+    u0p)≈dup
+
 using ReverseDiff, OrdinaryDiffEq, SciMLSensitivity, Test
 
 solvealg_test = Tsit5()
@@ -145,6 +171,11 @@ res4 = ReverseDiff.gradient(loss2, p0)
 @test res1≈res3 atol=1e-14
 @test res2≈res4 atol=1e-14
 
+res1 = loss(p0)
+res2 = Enzyme.gradient(Reverse, loss, p0)
+res3 = loss2(p0)
+res4 = Enzyme.gradient(Reverse, loss2, p0)
+
 # Test for recursion https://discourse.julialang.org/t/diffeqsensitivity-jl-issues-with-reversediffadjoint-sensealg/88774
 function ode!(derivative, state, parameters, t)
     derivative .= parameters
@@ -163,3 +194,4 @@ end
 const initial_state = ones(2)
 const solution_times = [1.0, 2.0]
 ReverseDiff.gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
+# Enzyme.gradient(Reverse, p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
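The pattern the new tests exercise is reverse-mode Enzyme driving a SciMLSensitivity-captured `solve`. Below is a minimal illustrative sketch, not part of the patch, assuming the out-of-place `prob` defined at the top of the test file and an Enzyme version where `Enzyme.gradient(Reverse, f, x)` returns the gradient array directly (as the `≈ dup` comparisons in the diff assume); the tolerances and `saveat` value are placeholders.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Enzyme

# Out-of-place ODE u' = u .* p, same as `prob` in the test file
prob = ODEProblem((u, p, t) -> u .* p, [2.0], (0.0, 1.0), [3.0])

# Scalar loss over the solution; `sensealg` selects how the inner solve is differentiated
loss(p) = sum(solve(prob, Tsit5(), p = p, abstol = 1e-12, reltol = 1e-12,
    saveat = 0.1, sensealg = InterpolatingAdjoint()))

# Reverse-mode Enzyme over the loss; the captured solve dispatches to the chosen sensealg
dp = Enzyme.gradient(Reverse, loss, [3.0])
```

This mirrors the `senseloss4(InterpolatingAdjoint())` case above; the `@test_broken` and `@test_throws` lines in the diff record which sensealg/Enzyme combinations are not yet expected to work.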