Skip to content

Commit

Permalink
Merge pull request #1098 from SciML/fake_enzyme
Browse files Browse the repository at this point in the history
Fake solver output type to please Enzyme
  • Loading branch information
ChrisRackauckas authored Aug 28, 2024
2 parents a7a8fbf + af2b880 commit 7c6bf22
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <
version = "7.65.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -46,6 +47,7 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Accessors = "0.1.36"
ADTypes = "0.1, 0.2, 1"
Adapt = "1.0, 2.0, 3.0, 4"
AlgebraicMultigrid = "0.6.0"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import SciMLBase: AbstractNonlinearProblem
using Adapt
using LinearSolve
using Parameters: @unpack
import Accessors: @reset
using StochasticDiffEq
import DiffEqNoiseProcess
import RandomNumbers: Xorshifts
Expand Down
17 changes: 17 additions & 0 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,10 @@ function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem,
end
out = DiffEqBase.sensitivity_solution(sol, u, ts)

if originator isa SciMLBase.EnzymeOriginator
@reset out.prob = prob
end

function forward_sensitivity_backpass(Δ)
adj = sum(eachindex(du)) do i
J = du[i]
Expand Down Expand Up @@ -737,6 +741,11 @@ function DiffEqBase._concrete_solve_forward(prob::SciMLBase.AbstractODEProblem,
_prob = ODEForwardSensitivityProblem(
prob.f, u0, prob.tspan, p, sensealg, callback = nothing)
sol = solve(_prob, args...; kwargs...)

if originator isa SciMLBase.EnzymeOriginator
@reset sol.prob = prob
end

u, du = extract_local_sensitivities(sol, Val(true))
_save_idxs = save_idxs === nothing ? (1:length(u0)) : save_idxs
ts = current_time(sol)
Expand Down Expand Up @@ -799,6 +808,10 @@ function DiffEqBase._concrete_solve_adjoint(
sol = solve(remake(prob, p = p, u0 = u0, callback = nothing),
alg, args...; saveat = _saveat, kwargs...)

if originator isa SciMLBase.EnzymeOriginator
@reset sol.prob = prob
end

# saveat values
# need all values here. Not only unique ones.
# if a callback is saving two times in primal solution, we also need to get it at least
Expand Down Expand Up @@ -1309,6 +1322,10 @@ function DiffEqBase._concrete_solve_adjoint(
kwargs_filtered...)
sol = SciMLBase.sensitivity_solution(sol, state_values(sol), current_time(sol))

if originator isa SciMLBase.EnzymeOriginator
@reset sol.prob = prob
end

if state_values(sol, 1) isa Array
return Array(sol)
else
Expand Down

0 comments on commit 7c6bf22

Please sign in to comment.