Skip to content

Commit

Permalink
Copy over work from Mooncake PR
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Nov 19, 2024
1 parent 99156d1 commit 444ebba
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -71,6 +72,7 @@ LinearSolve = "2"
Lux = "1"
Markdown = "1.10"
ModelingToolkit = "9.42"
Mooncake = "0.4.44"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1"
Optimization = "4"
Expand Down
2 changes: 2 additions & 0 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, Zer
using Enzyme: Enzyme
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Mooncake: Mooncake
using Tracker: Tracker, TrackedArray
using ReverseDiff: ReverseDiff
using Zygote: Zygote
Expand Down Expand Up @@ -78,6 +79,7 @@ include("concrete_solve.jl")
include("second_order.jl")
include("steadystate_adjoint.jl")
include("sde_tools.jl")
include("tmp_mooncake_rules.jl")

export extract_local_sensitivities

Expand Down
15 changes: 15 additions & 0 deletions src/tmp_mooncake_rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This file will be removed, and put in an extension in DiffEqBase before merging.
Mooncake.@from_rrule(
Mooncake.MinimalCtx,
Tuple{
typeof(DiffEqBase.solve_up),
DiffEqBase.AbstractDEProblem,
Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm},
Any,
Any,
Any,
},
true,
)

Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
36 changes: 35 additions & 1 deletion test/alternative_ad_frontend.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker, Enzyme,
FiniteDiff
FiniteDiff, Mooncake
using Test
Enzyme.API.typeWarning!(false)

mooncake_gradient(f, x) = Mooncake.value_and_gradient!!(build_rrule(f, x), f, x)[2][2]

odef(du, u, p, t) = du .= u .* p
const prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0])

Expand All @@ -17,7 +19,9 @@ u0p = [2.0, 3.0]
du0p = zeros(2)
dup = Zygote.gradient(senseloss0(InterpolatingAdjoint()), u0p)[1]
Enzyme.autodiff(Reverse, senseloss0(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p))
dup_mc = mooncake_gradient(senseloss0(InterpolatingAdjoint()), u0p)
@test du0p dup
@test dup_mc dup

struct senseloss{T}
sense::T
Expand Down Expand Up @@ -56,6 +60,12 @@ dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
@test only(Enzyme.gradient(Reverse, senseloss(ForwardDiffSensitivity()), u0p)) dup
@test_broken only(Enzyme.gradient(Reverse, senseloss(ForwardSensitivity()), u0p)) dup # broken because ForwardSensitivity not compatible with perturbing u0

@test mooncake_gradient(senseloss(InterpolatingAdjoint()), u0p) dup
@test_broken mooncake_gradient(senseloss(ReverseDiffAdjoint()), u0p) dup
@test_broken mooncake_gradient(senseloss(TrackerAdjoint()), u0p) dup
@test mooncake_gradient(senseloss(ForwardDiffSensitivity()), u0p) dup
@test_broken mooncake_gradient(senseloss(ForwardSensitivity()), u0p) dup # broken because ForwardSensitivity not compatible with perturbing u0

struct senseloss2{T}
sense::T
end
Expand Down Expand Up @@ -90,6 +100,12 @@ dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1]
@test_broken only(Enzyme.gradient(Reverse, senseloss2(ForwardDiffSensitivity()), u0p)) dup
@test_broken only(Enzyme.gradient(Reverse, senseloss2(ForwardSensitivity()), u0p)) dup # broken because ForwardSensitivity not compatible with perturbing u0

@test mooncake_gradient(senseloss2(InterpolatingAdjoint()), u0p) dup
@test_broken mooncake_gradient(senseloss2(ReverseDiffAdjoint()), u0p) dup
@test_broken mooncake_gradient(senseloss2(TrackerAdjoint()), u0p) dup
@test mooncake_gradient(senseloss2(ForwardDiffSensitivity()), u0p) dup
@test_broken mooncake_gradient(senseloss2(ForwardSensitivity()), u0p) dup # broken because ForwardSensitivity not compatible with perturbing u0

struct senseloss3{T}
sense::T
end
Expand Down Expand Up @@ -122,6 +138,12 @@ dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1]
@test_broken only(Enzyme.gradient(Reverse, senseloss3(ForwardDiffSensitivity()), u0p)) dup
@test_broken only(Enzyme.gradient(Reverse, senseloss3(ForwardSensitivity()), u0p)) dup

@test mooncake_gradient(senseloss3(InterpolatingAdjoint()), u0p) dup
@test_broken mooncake_gradient(senseloss3(ReverseDiffAdjoint()), u0p) dup
@test_broken mooncake_gradient(senseloss3(TrackerAdjoint()), u0p) dup
@test mooncake_gradient(senseloss3(ForwardDiffSensitivity()), u0p) dup
@test_broken mooncake_gradient(senseloss3(ForwardSensitivity()), u0p) dup

struct senseloss4{T}
sense::T
end
Expand Down Expand Up @@ -156,6 +178,12 @@ dup = Zygote.gradient(senseloss4(InterpolatingAdjoint()), u0p)[1]
@test only(Enzyme.gradient(Reverse, senseloss4(ForwardDiffSensitivity()), u0p)) dup
@test_broken only(Enzyme.gradient(Reverse, senseloss4(ForwardSensitivity()), u0p)) dup # broken because ForwardSensitivity not compatible with perturbing u0

@test mooncake_gradient(senseloss4(InterpolatingAdjoint()), u0p) dup
@test_broken mooncake_gradient(senseloss4(ReverseDiffAdjoint()), u0p) dup
@test_broken mooncake_gradient(senseloss4(TrackerAdjoint()), u0p) dup
@test mooncake_gradient(senseloss4(ForwardDiffSensitivity()), u0p) dup
@test_broken mooncake_gradient(senseloss4(ForwardSensitivity()), u0p) dup

solvealg_test = Tsit5()
sensealg_test = InterpolatingAdjoint()
tspan = (0.0, 1.0)
Expand Down Expand Up @@ -186,6 +214,10 @@ res4 = ReverseDiff.gradient(loss2, p0)
@test_broken res2Enzyme.gradient(Reverse, loss, p0) atol=1e-14
@test_broken res4Enzyme.gradient(Reverse, loss2, p0) atol=1e-14

# I think we're just not successfully hitting the rrule here.
@test_broken res2 mooncake_gradient(loss, p0)
res4 mooncake_gradient(loss2, p0)

# Test for recursion https://discourse.julialang.org/t/diffeqsensitivity-jl-issues-with-reversediffadjoint-sensealg/88774
function ode!(derivative, state, parameters, t)
derivative .= parameters
Expand All @@ -205,6 +237,7 @@ const initial_state = ones(2)
const solution_times = [1.0, 2.0]
ReverseDiff.gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
# Enzyme.gradient(Reverse, p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
# mooncake_gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))

# https://github.com/SciML/SciMLSensitivity.jl/issues/943

Expand Down Expand Up @@ -249,3 +282,4 @@ grad_rd = ReverseDiff.gradient(loss2, p)
@test grad_fdgrad_fi atol=1e-2
@test grad_fd grad_zg atol=1e-4
@test grad_fd grad_rd atol=1e-4
@test_broken mooncake_gradient(loss2, p) grad_rd atol=1e-4 # appears to not be hitting the rule

0 comments on commit 444ebba

Please sign in to comment.