Skip to content

Commit

Permalink
Setup tests for driving with Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 7, 2023
1 parent 29f85a0 commit e2e82a8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/src/manual/differential_equation_sensitivities.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ requiring the user to do any of the setup.
Current AD libraries whose calls are captured by the sensitivity
system are:

- [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)
- [Zygote.jl](https://fluxml.ai/Zygote.jl/stable/)
- [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl)
- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl)
Expand Down
50 changes: 41 additions & 9 deletions test/alternative_ad_frontend.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker
using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker, Enzyme
using Test

prob = ODEProblem((u, p, t) -> u .* p, [2.0], (0.0, 1.0), [3.0])

struct senseloss
sense::Any
struct senseloss{T}
sense::T
end
function (f::senseloss)(u0p)
sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12,
Expand Down Expand Up @@ -34,8 +34,15 @@ dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]

@test ForwardDiff.gradient(senseloss(InterpolatingAdjoint()), u0p) dup

struct senseloss2
sense::Any
@test Enzyme.gradient(Reverse, senseloss(InterpolatingAdjoint()), u0p) dup
@test_broken Enzyme.gradient(Reverse, senseloss(ReverseDiffAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss(TrackerAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss(ForwardDiffSensitivity()), u0p) dup
@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError ReverseDiff.gradient(Reverse, senseloss(ForwardSensitivity()),
u0p)dup

struct senseloss2{T}
sense::T
end
prob2 = ODEProblem((du, u, p, t) -> du .= u .* p, [2.0], (0.0, 1.0), [3.0])

Expand All @@ -62,8 +69,14 @@ dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1]

@test ForwardDiff.gradient(senseloss2(InterpolatingAdjoint()), u0p) dup

struct senseloss3
sense::Any
@test Enzyme.gradient(Reverse, senseloss2(InterpolatingAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss2(ReverseDiffAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss2(TrackerAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss2(ForwardDiffSensitivity()), u0p) dup
@test_broken Enzyme.gradient(Reverse, senseloss2(ForwardSensitivity()), u0p) dup

struct senseloss3{T}
sense::T
end
function (f::senseloss3)(u0p)
sum(solve(prob2, Tsit5(), p = u0p, abstol = 1e-12,
Expand All @@ -88,8 +101,14 @@ dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1]

@test ForwardDiff.gradient(senseloss3(InterpolatingAdjoint()), u0p) dup

struct senseloss4
sense::Any
@test Enzyme.gradient(Reverse, senseloss3(InterpolatingAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss3(ReverseDiffAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss3(TrackerAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss3(ForwardDiffSensitivity()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss3(ForwardSensitivity()), u0p) dup

struct senseloss4{T}
sense::T
end
function (f::senseloss4)(u0p)
sum(solve(prob, Tsit5(), p = u0p, abstol = 1e-12,
Expand All @@ -116,6 +135,13 @@ dup = Zygote.gradient(senseloss4(InterpolatingAdjoint()), u0p)[1]

@test ForwardDiff.gradient(senseloss4(InterpolatingAdjoint()), u0p) dup

@test Enzyme.gradient(Reverse, senseloss4(InterpolatingAdjoint()), u0p) dup
@test_broken Enzyme.gradient(Reverse, senseloss4(ReverseDiffAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss4(TrackerAdjoint()), u0p) dup
@test Enzyme.gradient(Reverse, senseloss4(ForwardDiffSensitivity()), u0p) dup
@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError Enzyme.gradient(Reverse, senseloss4(ForwardSensitivity()),
u0p)dup

using ReverseDiff, OrdinaryDiffEq, SciMLSensitivity, Test

solvealg_test = Tsit5()
Expand Down Expand Up @@ -145,6 +171,11 @@ res4 = ReverseDiff.gradient(loss2, p0)
@test res1res3 atol=1e-14
@test res2res4 atol=1e-14

res1 = loss(p0)
res2 = Enzyme.gradient(Reverse, loss, p0)
res3 = loss2(p0)
res4 = Enzyme.gradient(Reverse, loss2, p0)

# Test for recursion https://discourse.julialang.org/t/diffeqsensitivity-jl-issues-with-reversediffadjoint-sensealg/88774
function ode!(derivative, state, parameters, t)
derivative .= parameters
Expand All @@ -163,3 +194,4 @@ end
const initial_state = ones(2)
const solution_times = [1.0, 2.0]
ReverseDiff.gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
# Enzyme.gradient(Reverse, p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))

0 comments on commit e2e82a8

Please sign in to comment.