Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Mooncake to Alternative AD Frontends #1151

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
2 changes: 2 additions & 0 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, Zer
using Enzyme: Enzyme
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Mooncake: Mooncake
using Tracker: Tracker, TrackedArray
using ReverseDiff: ReverseDiff
using Zygote: Zygote
Expand Down Expand Up @@ -78,6 +79,7 @@ include("concrete_solve.jl")
include("second_order.jl")
include("steadystate_adjoint.jl")
include("sde_tools.jl")
include("tmp_mooncake_rules.jl")

export extract_local_sensitivities

Expand Down
15 changes: 15 additions & 0 deletions src/tmp_mooncake_rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This file will be removed, and put in an extension in DiffEqBase before merging.
Mooncake.@from_rrule(
Mooncake.MinimalCtx,
Tuple{
typeof(DiffEqBase.solve_up),
DiffEqBase.AbstractDEProblem,
Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm},
Any,
Any,
Any,
},
true,
)

Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
36 changes: 35 additions & 1 deletion test/alternative_ad_frontend.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker, Enzyme,
FiniteDiff
FiniteDiff, Mooncake
using Test
Enzyme.API.typeWarning!(false)

mooncake_gradient(f, x) = Mooncake.value_and_gradient!!(build_rrule(f, x), f, x)[2][2]

odef(du, u, p, t) = du .= u .* p
const prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0])

Expand All @@ -17,7 +19,9 @@ u0p = [2.0, 3.0]
du0p = zeros(2)
dup = Zygote.gradient(senseloss0(InterpolatingAdjoint()), u0p)[1]
Enzyme.autodiff(Reverse, senseloss0(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p))
dup_mc = mooncake_gradient(senseloss0(InterpolatingAdjoint()), u0p)
@test du0p ≈ dup
@test dup_mc ≈ dup

struct senseloss{T}
sense::T
Expand Down Expand Up @@ -56,6 +60,12 @@ dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
@test only(Enzyme.gradient(Reverse, senseloss(ForwardDiffSensitivity()), u0p)) ≈ dup
@test_broken only(Enzyme.gradient(Reverse, senseloss(ForwardSensitivity()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0

@test mooncake_gradient(senseloss(InterpolatingAdjoint()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss(ReverseDiffAdjoint()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss(TrackerAdjoint()), u0p) ≈ dup
@test mooncake_gradient(senseloss(ForwardDiffSensitivity()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss(ForwardSensitivity()), u0p) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0

struct senseloss2{T}
sense::T
end
Expand Down Expand Up @@ -90,6 +100,12 @@ dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1]
@test_broken only(Enzyme.gradient(Reverse, senseloss2(ForwardDiffSensitivity()), u0p)) ≈ dup
@test_broken only(Enzyme.gradient(Reverse, senseloss2(ForwardSensitivity()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0

@test mooncake_gradient(senseloss2(InterpolatingAdjoint()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss2(ReverseDiffAdjoint()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss2(TrackerAdjoint()), u0p) ≈ dup
@test mooncake_gradient(senseloss2(ForwardDiffSensitivity()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss2(ForwardSensitivity()), u0p) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0

struct senseloss3{T}
sense::T
end
Expand Down Expand Up @@ -122,6 +138,12 @@ dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1]
@test_broken only(Enzyme.gradient(Reverse, senseloss3(ForwardDiffSensitivity()), u0p)) ≈ dup
@test_broken only(Enzyme.gradient(Reverse, senseloss3(ForwardSensitivity()), u0p)) ≈ dup

@test mooncake_gradient(senseloss3(InterpolatingAdjoint()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss3(ReverseDiffAdjoint()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss3(TrackerAdjoint()), u0p) ≈ dup
@test mooncake_gradient(senseloss3(ForwardDiffSensitivity()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss3(ForwardSensitivity()), u0p) ≈ dup

struct senseloss4{T}
sense::T
end
Expand Down Expand Up @@ -156,6 +178,12 @@ dup = Zygote.gradient(senseloss4(InterpolatingAdjoint()), u0p)[1]
@test only(Enzyme.gradient(Reverse, senseloss4(ForwardDiffSensitivity()), u0p)) ≈ dup
@test_broken only(Enzyme.gradient(Reverse, senseloss4(ForwardSensitivity()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0

@test mooncake_gradient(senseloss4(InterpolatingAdjoint()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss4(ReverseDiffAdjoint()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss4(TrackerAdjoint()), u0p) ≈ dup
@test mooncake_gradient(senseloss4(ForwardDiffSensitivity()), u0p) ≈ dup
@test_broken mooncake_gradient(senseloss4(ForwardSensitivity()), u0p) ≈ dup

solvealg_test = Tsit5()
sensealg_test = InterpolatingAdjoint()
tspan = (0.0, 1.0)
Expand Down Expand Up @@ -186,6 +214,10 @@ res4 = ReverseDiff.gradient(loss2, p0)
@test_broken res2≈Enzyme.gradient(Reverse, loss, p0) atol=1e-14
@test_broken res4≈Enzyme.gradient(Reverse, loss2, p0) atol=1e-14

# I think we're just not successfully hitting the rrule here.
@test_broken res2 ≈ mooncake_gradient(loss, p0)
res4 ≈ mooncake_gradient(loss2, p0)

# Test for recursion https://discourse.julialang.org/t/diffeqsensitivity-jl-issues-with-reversediffadjoint-sensealg/88774
function ode!(derivative, state, parameters, t)
derivative .= parameters
Expand All @@ -205,6 +237,7 @@ const initial_state = ones(2)
const solution_times = [1.0, 2.0]
ReverseDiff.gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
# Enzyme.gradient(Reverse, p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))
# mooncake_gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2))

# https://github.com/SciML/SciMLSensitivity.jl/issues/943

Expand Down Expand Up @@ -249,3 +282,4 @@ grad_rd = ReverseDiff.gradient(loss2, p)
@test grad_fd≈grad_fi atol=1e-2
@test grad_fd ≈ grad_zg atol=1e-4
@test grad_fd ≈ grad_rd atol=1e-4
@test_broken mooncake_gradient(loss2, p) ≈ grad_rd atol=1e-4 # appears to not be hitting the rule
Loading