diff --git a/Project.toml b/Project.toml index c4706497a..f1ced38c1 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index d05c054f4..e50da5256 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -44,6 +44,7 @@ using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, Zer using Enzyme: Enzyme using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff +using Mooncake: Mooncake using Tracker: Tracker, TrackedArray using ReverseDiff: ReverseDiff using Zygote: Zygote @@ -78,6 +79,7 @@ include("concrete_solve.jl") include("second_order.jl") include("steadystate_adjoint.jl") include("sde_tools.jl") +include("tmp_mooncake_rules.jl") export extract_local_sensitivities diff --git a/src/tmp_mooncake_rules.jl b/src/tmp_mooncake_rules.jl new file mode 100644 index 000000000..84368ab49 --- /dev/null +++ b/src/tmp_mooncake_rules.jl @@ -0,0 +1,15 @@ +# This file will be removed, and put in an extension in DiffEqBase before merging. +Mooncake.@from_rrule( + Mooncake.MinimalCtx, + Tuple{ + typeof(DiffEqBase.solve_up), + DiffEqBase.AbstractDEProblem, + Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, + Any, + Any, + Any, + }, + true, +) + +Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any} \ No newline at end of file diff --git a/test/alternative_ad_frontend.jl b/test/alternative_ad_frontend.jl index 79b45009c..f59585927 100644 --- a/test/alternative_ad_frontend.jl +++ b/test/alternative_ad_frontend.jl @@ -1,8 +1,10 @@ using OrdinaryDiffEq, SciMLSensitivity, ForwardDiff, Zygote, ReverseDiff, Tracker, Enzyme, - FiniteDiff + FiniteDiff, Mooncake using Test Enzyme.API.typeWarning!(false) +mooncake_gradient(f, x) = Mooncake.value_and_gradient!!(build_rrule(f, x), f, x)[2][2] + odef(du, u, p, t) = du .= u .* p const prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0]) @@ -17,7 +19,9 @@ u0p = [2.0, 3.0] du0p = zeros(2) dup = Zygote.gradient(senseloss0(InterpolatingAdjoint()), u0p)[1] Enzyme.autodiff(Reverse, senseloss0(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p)) +dup_mc = mooncake_gradient(senseloss0(InterpolatingAdjoint()), u0p) @test du0p ≈ dup +@test dup_mc ≈ dup struct senseloss{T} sense::T @@ -56,6 +60,12 @@ dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1] @test only(Enzyme.gradient(Reverse, senseloss(ForwardDiffSensitivity()), u0p)) ≈ dup @test_broken only(Enzyme.gradient(Reverse, senseloss(ForwardSensitivity()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0 +@test mooncake_gradient(senseloss(InterpolatingAdjoint()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss(ReverseDiffAdjoint()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss(TrackerAdjoint()), u0p) ≈ dup +@test mooncake_gradient(senseloss(ForwardDiffSensitivity()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss(ForwardSensitivity()), u0p) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0 + struct senseloss2{T} sense::T end @@ -90,6 +100,12 @@ dup = Zygote.gradient(senseloss2(InterpolatingAdjoint()), u0p)[1] @test_broken only(Enzyme.gradient(Reverse, senseloss2(ForwardDiffSensitivity()), u0p)) ≈ dup @test_broken only(Enzyme.gradient(Reverse, senseloss2(ForwardSensitivity()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0 +@test mooncake_gradient(senseloss2(InterpolatingAdjoint()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss2(ReverseDiffAdjoint()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss2(TrackerAdjoint()), u0p) ≈ dup +@test mooncake_gradient(senseloss2(ForwardDiffSensitivity()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss2(ForwardSensitivity()), u0p) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0 + struct senseloss3{T} sense::T end @@ -122,6 +138,12 @@ dup = Zygote.gradient(senseloss3(InterpolatingAdjoint()), u0p)[1] @test_broken only(Enzyme.gradient(Reverse, senseloss3(ForwardDiffSensitivity()), u0p)) ≈ dup @test_broken only(Enzyme.gradient(Reverse, senseloss3(ForwardSensitivity()), u0p)) ≈ dup +@test mooncake_gradient(senseloss3(InterpolatingAdjoint()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss3(ReverseDiffAdjoint()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss3(TrackerAdjoint()), u0p) ≈ dup +@test mooncake_gradient(senseloss3(ForwardDiffSensitivity()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss3(ForwardSensitivity()), u0p) ≈ dup + struct senseloss4{T} sense::T end @@ -156,6 +178,12 @@ dup = Zygote.gradient(senseloss4(InterpolatingAdjoint()), u0p)[1] @test only(Enzyme.gradient(Reverse, senseloss4(ForwardDiffSensitivity()), u0p)) ≈ dup @test_broken only(Enzyme.gradient(Reverse, senseloss4(ForwardSensitivity()), u0p)) ≈ dup # broken because ForwardSensitivity not compatible with perturbing u0 +@test mooncake_gradient(senseloss4(InterpolatingAdjoint()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss4(ReverseDiffAdjoint()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss4(TrackerAdjoint()), u0p) ≈ dup +@test mooncake_gradient(senseloss4(ForwardDiffSensitivity()), u0p) ≈ dup +@test_broken mooncake_gradient(senseloss4(ForwardSensitivity()), u0p) ≈ dup + solvealg_test = Tsit5() sensealg_test = InterpolatingAdjoint() tspan = (0.0, 1.0) @@ -186,6 +214,10 @@ res4 = ReverseDiff.gradient(loss2, p0) @test_broken res2≈Enzyme.gradient(Reverse, loss, p0) atol=1e-14 @test_broken res4≈Enzyme.gradient(Reverse, loss2, p0) atol=1e-14 +# I think we're just not successfully hitting the rrule here. +@test_broken res2 ≈ mooncake_gradient(loss, p0) +res4 ≈ mooncake_gradient(loss2, p0) + # Test for recursion https://discourse.julialang.org/t/diffeqsensitivity-jl-issues-with-reversediffadjoint-sensealg/88774 function ode!(derivative, state, parameters, t) derivative .= parameters @@ -205,6 +237,7 @@ const initial_state = ones(2) const solution_times = [1.0, 2.0] ReverseDiff.gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2)) # Enzyme.gradient(Reverse, p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2)) +# mooncake_gradient(p -> sum(sum(solve_euler(initial_state, solution_times, p))), zeros(2)) # https://github.com/SciML/SciMLSensitivity.jl/issues/943 @@ -249,3 +282,4 @@ grad_rd = ReverseDiff.gradient(loss2, p) @test grad_fd≈grad_fi atol=1e-2 @test grad_fd ≈ grad_zg atol=1e-4 @test grad_fd ≈ grad_rd atol=1e-4 +@test_broken mooncake_gradient(loss2, p) ≈ grad_rd atol=1e-4 # appears to not be hitting the rule