From b0bff4153b9a416b58a61ce7e574a312ecf7c7b4 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 3 Nov 2023 05:53:45 -0400 Subject: [PATCH 1/2] Remove AD piracy functions by moving to SciMLBase This also bumps to v1.9 because of SciMLBase, and thus drops the extra Requires.jl stuff --- Project.toml | 25 ++-- ext/DiffEqBaseChainRulesCoreExt.jl | 11 ++ ext/DiffEqBaseZygoteExt.jl | 60 --------- src/DiffEqBase.jl | 16 +-- src/chainrules.jl | 188 ----------------------------- src/init.jl | 55 --------- src/solve.jl | 11 ++ src/utils.jl | 4 + 8 files changed, 36 insertions(+), 334 deletions(-) create mode 100644 ext/DiffEqBaseChainRulesCoreExt.jl delete mode 100644 ext/DiffEqBaseZygoteExt.jl delete mode 100644 src/chainrules.jl delete mode 100644 src/init.jl diff --git a/Project.toml b/Project.toml index e3ef7cf7d..bef8bad65 100644 --- a/Project.toml +++ b/Project.toml @@ -5,9 +5,7 @@ version = "6.136.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" @@ -25,7 +23,6 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -35,9 +32,9 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" @@ -47,7 +44,6 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DiffEqBaseDistributionsExt = "Distributions" @@ -59,7 +55,6 @@ DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements" DiffEqBaseReverseDiffExt = "ReverseDiff" DiffEqBaseTrackerExt = "Tracker" DiffEqBaseUnitfulExt = "Unitful" -DiffEqBaseZygoteExt = "Zygote" [compat] ArrayInterface = "7" @@ -73,31 +68,30 @@ FastBroadcast = "0.2" ForwardDiff = "0.10" FunctionWrappers = "1.0" FunctionWrappersWrappers = "0.1" -LinearAlgebra = "1.6" -Logging = "1.6" -Markdown = "1.6" +LinearAlgebra = "1.9" +Logging = "1.9" +Markdown = "1.9" MuladdMacro = "0.2.1" Parameters = "0.12.0" PreallocationTools = "0.4" PrecompileTools = "1" -Printf = "1.6" +Printf = "1.9" RecursiveArrayTools = "2" Reexport = "1.0" -Requires = "1.0" -SciMLBase = "2.4.0" +SciMLBase = "2.7.0" SciMLOperators = "0.2, 0.3" Setfield = "0.8, 1" -SparseArrays = "1.6" +SparseArrays = "1.9" Static = "0.7, 0.8" StaticArraysCore = "1.4" Statistics = "1" Tricks = "0.1.6" TruncatedStacktraces = "1" -ZygoteRules = "0.2" -julia = "1.6" +julia = "1.9" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -116,7 +110,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua"] diff --git a/ext/DiffEqBaseChainRulesCoreExt.jl b/ext/DiffEqBaseChainRulesCoreExt.jl new file mode 100644 index 000000000..06479ebfa --- /dev/null +++ b/ext/DiffEqBaseChainRulesCoreExt.jl @@ -0,0 +1,11 @@ +module DiffEqBaseChainRulesCoreExt + +using DiffEqBase +import DiffEqBase: numargs + +import ChainRulesCore +import ChainRulesCore: NoTangent + +ChainRulesCore.rrule(::typeof(numargs), f) = (numargs(f), df -> (NoTangent(), NoTangent())) + +end \ No newline at end of file diff --git a/ext/DiffEqBaseZygoteExt.jl b/ext/DiffEqBaseZygoteExt.jl deleted file mode 100644 index 503265568..000000000 --- a/ext/DiffEqBaseZygoteExt.jl +++ /dev/null @@ -1,60 +0,0 @@ -module DiffEqBaseZygoteExt - -if isdefined(Base, :get_extension) - using DiffEqBase - import DiffEqBase: value - import Zygote -else - using ..DiffEqBase - import ..DiffEqBase: value - import ..Zygote -end - -function ∇tmap(cx, f, args...) - ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...) - if isempty(ys_and_backs) - ys_and_backs, _ -> (NoTangent(), NoTangent()) - else - ys, backs = Zygote.unzip(ys_and_backs) - function ∇tmap_internal(Δ) - Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ) - Δf_and_args = Zygote.unzip(Δf_and_args_zipped) - Δf = reduce(Zygote.accum, Δf_and_args[1]) - (Δf, Δf_and_args[2:end]...) - end - ys, ∇tmap_internal - end -end - -function ∇responsible_map(cx, f, args...) - ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...), - args...) - if isempty(ys_and_backs) - ys_and_backs, _ -> (NoTangent(), NoTangent()) - else - ys, backs = Zygote.unzip(ys_and_backs) - ys, - function ∇responsible_map_internal(Δ) - # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful. - Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ), - Zygote._tryreverse(SciMLBase.responsible_map, - backs, Δ)...) - Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map, - Δf_and_args_zipped)) - Δf = reduce(Zygote.accum, Δf_and_args[1]) - (Δf, Δf_and_args[2:end]...) - end - end -end - -Zygote.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...) - ∇tmap(__context__, f, args...) -end - -Zygote.@adjoint function SciMLBase.responsible_map(f, - args::Union{AbstractArray, Tuple - }...) - ∇responsible_map(__context__, f, args...) -end - -end diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index c73daadc0..c5c348d6f 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -3,9 +3,6 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_methods")) @eval Base.Experimental.@max_methods 1 end -if !isdefined(Base, :get_extension) - using Requires -end import PrecompileTools @@ -28,14 +25,10 @@ PrecompileTools.@recompile_invalidations begin using Static: reduce_tup - import ChainRulesCore import RecursiveArrayTools import SparseArrays import TruncatedStacktraces - import ChainRulesCore: NoTangent, @non_differentiable - import ZygoteRules - using Setfield using ForwardDiff @@ -140,13 +133,10 @@ include("callbacks.jl") include("common_defaults.jl") include("solve.jl") include("internal_euler.jl") -include("init.jl") include("forwarddiff.jl") -include("chainrules.jl") - include("termination_conditions.jl") - include("norecompile.jl") + # This is only used for oop stiff solvers default_factorize(A) = lu(A; check = false) @@ -181,8 +171,4 @@ export NLSolveTerminationMode, export KeywordArgError, KeywordArgWarn, KeywordArgSilent -if !isdefined(Base, :get_extension) - include("../ext/DiffEqBaseDistributionsExt.jl") -end - end # module diff --git a/src/chainrules.jl b/src/chainrules.jl deleted file mode 100644 index 5e376f2bf..000000000 --- a/src/chainrules.jl +++ /dev/null @@ -1,188 +0,0 @@ -function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...) - function ODEProblemAdjoint(ȳ) - (NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type) - end - - ODEProblem(args...; kwargs...), ODEProblemAdjoint -end - -function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...) - function SDEProblemAdjoint(ȳ) - (NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type) - end - - SDEProblem(args...; kwargs...), SDEProblemAdjoint -end - -function ChainRulesCore.rrule(::Type{ - <:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, - T11, T12, - }}, u, - args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, - T12} - function ODESolutionAdjoint(ȳ) - (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) - end - - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), - ODESolutionAdjoint -end - -ZygoteRules.@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 -}(u, - args...) where {T1, T2, T3, T4, T5, T6, T7, T8, - T9, T10, T11, T12} - function ODESolutionAdjoint(ȳ) - (ȳ, ntuple(_ -> nothing, length(args))...) - end - - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), - ODESolutionAdjoint -end - -function ChainRulesCore.rrule(::Type{ - <:ODESolution{uType, tType, isinplace, P, NP, F, G, K, - ND, - }}, u, - args...) where {uType, tType, isinplace, P, NP, F, G, K, ND} - function SDESolutionAdjoint(ȳ) - (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) - end - - SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint -end - -ZygoteRules.@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u, - args...) where - {uType, tType, isinplace, P, NP, F, G, K, ND} - function SDESolutionAdjoint(ȳ) - (ȳ, ntuple(_ -> nothing, length(args))...) - end - - SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint -end - -ZygoteRules.@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, - args...) where { - T, - N, - uType, - R, - P, - A, - O, - uType2, -} - function NonlinearSolutionAdjoint(ȳ) - (ȳ, ntuple(_ -> nothing, length(args))...) - end - NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint -end - -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution, - ::Val{:u}) - function solu_adjoint(Δ) - zerou = zero(sol.prob.u0) - _Δ = @. ifelse(Δ == nothing, (zerou,), Δ) - (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) - end - sol.u, solu_adjoint -end - -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution, - ::Val{:u}) - function solu_adjoint(Δ) - zerou = zero(sol.prob.u0) - _Δ = @. ifelse(Δ == nothing, zerou, Δ) - (DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),) - end - sol.u, solu_adjoint -end - -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution, - ::Val{:u}) - function solu_adjoint(Δ) - zerou = zero(sol.u) - _Δ = @. ifelse(Δ == nothing, zerou, Δ) - (DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),) - end - sol.u, solu_adjoint -end - -function ChainRulesCore.rrule(::DiffEqBase.EnsembleSolution, sim, time, converged) - out = EnsembleSolution(sim, time, converged) - function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} - arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] - for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]] - (NoTangent(), EnsembleSolution(arrarr, 0.0, true), NoTangent(), NoTangent()) - end - function EnsembleSolution_adjoint(p̄::EnsembleSolution) - (NoTangent(), p̄, NoTangent(), NoTangent()) - end - out, EnsembleSolution_adjoint -end - -ZygoteRules.@adjoint function DiffEqBase.EnsembleSolution(sim, time, converged) - out = EnsembleSolution(sim, time, converged) - function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} - arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] - for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]] - (EnsembleSolution(arrarr, 0.0, true), nothing, nothing) - end - function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - (EnsembleSolution(p̄, 0.0, true), nothing, nothing) - end - function EnsembleSolution_adjoint(p̄::EnsembleSolution) - (p̄, nothing, nothing) - end - out, EnsembleSolution_adjoint -end - -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, - ::Val{:u}) - sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),) -end - -#= -ChainRulesCore.frule(f::ODEFunction,u,p,t) - if f.jvp === nothing - ChainRulesCore.frule(f.f,u,p,t) - else - function ode_jvp(f,du,dp,dt) - f.jvp_u(du,u,p,t) + f.jvp_p(dp,u,p,t) + f.tgrad(u,p,t)*dt - end - f.f(u,p,t),ode_jvp -end -=# - -#= -function ChainRulesCore.rrule(f::ODEFunction,u,p,t) - if f.vjp === nothing - ChainRulesCore.rrule(f.f,u,p,t) - else - f.vjp(u,p,t) - end -end -=# - -#= -ZygoteRules.@adjoint function (f::ODEFunction)(u,p,t) - if f.vjp === nothing - ZygoteRules._pullback(f.f,u,p,t) - else - f.vjp(u,p,t) - end -end -=# - -#= -ZygoteRules.@adjoint! function (f::ODEFunction)(du,u,p,t) - if f.vjp === nothing - ZygoteRules._pullback(f.f,du,u,p,t) - else - f.vjp(du,u,p,t) - end -end -=# - -ChainRulesCore.rrule(::typeof(numargs), f) = (numargs(f), df -> (NoTangent(), NoTangent())) diff --git a/src/init.jl b/src/init.jl deleted file mode 100644 index ba245c09d..000000000 --- a/src/init.jl +++ /dev/null @@ -1,55 +0,0 @@ -value(x) = x -promote_tspan(u0, p, tspan, prob, kwargs) = _promote_tspan(tspan, kwargs) -function _promote_tspan(tspan, kwargs) - if (dt = get(kwargs, :dt, nothing)) !== nothing - tspan1, tspan2, _ = promote(tspan..., dt) - return (tspan1, tspan2) - else - return tspan - end -end -isdistribution(u0) = false - -function SciMLBase.tmap(args...) - error("Zygote must be added to differentiate Zygote? If you see this error, report it.") -end - -@static if !isdefined(Base, :get_extension) - function __init__() - @require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin - include("../ext/DiffEqBaseMeasurementsExt.jl") - end - - @require MonteCarloMeasurements="0987c9cc-fe09-11e8-30f0-b96dd679fdca" begin - include("../ext/DiffEqBaseMonteCarloMeasurementsExt.jl") - end - - @require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" begin - include("../ext/DiffEqBaseUnitfulExt.jl") - end - - @require GeneralizedGenerated="6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" begin - include("../ext/DiffEqBaseGeneralizedGeneratedExt.jl") - end - - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin - include("../ext/DiffEqBaseTrackerExt.jl") - end - - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("../ext/DiffEqBaseReverseDiffExt.jl") - end - - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("../ext/DiffEqBaseZygoteExt.jl") - end - - @require MPI="da04e1cc-30fd-572f-bb4f-1f8673147195" begin - include("../ext/DiffEqBaseMPIExt.jl") - end - - @require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("../ext/DiffEqBaseEnzymeExt.jl") - end - end -end diff --git a/src/solve.jl b/src/solve.jl index d1ce09568..d44f1a524 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1146,6 +1146,17 @@ function get_concrete_problem(prob::DDEProblem, isadapt; kwargs...) remake(prob; u0 = u0, tspan = tspan, p = p, constant_lags = constant_lags) end +# Most are extensions +promote_tspan(u0, p, tspan, prob, kwargs) = _promote_tspan(tspan, kwargs) +function _promote_tspan(tspan, kwargs) + if (dt = get(kwargs, :dt, nothing)) !== nothing + tspan1, tspan2, _ = promote(tspan..., dt) + return (tspan1, tspan2) + else + return tspan + end +end + function promote_f(f::F, ::Val{specialize}, u0, p, t) where {F, specialize} # Ensure our jacobian will be of the same type as u0 uElType = u0 === nothing ? Float64 : eltype(u0) diff --git a/src/utils.jl b/src/utils.jl index 9274008fd..f6ccbcc5a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,7 @@ +# Handled in Extensions +value(x) = x +isdistribution(u0) = false + _vec(v) = vec(v) _vec(v::Number) = v _vec(v::AbstractSciMLScalarOperator) = v From 042f08efde496974324787a0c1c380f4e74b06ae Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 3 Nov 2023 06:40:39 -0400 Subject: [PATCH 2/2] Remove extra chainrulescore --- ext/DiffEqBaseChainRulesCoreExt.jl | 17 +++++++++++++++++ src/solve.jl | 18 ------------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/ext/DiffEqBaseChainRulesCoreExt.jl b/ext/DiffEqBaseChainRulesCoreExt.jl index 06479ebfa..3fb586a5a 100644 --- a/ext/DiffEqBaseChainRulesCoreExt.jl +++ b/ext/DiffEqBaseChainRulesCoreExt.jl @@ -7,5 +7,22 @@ import ChainRulesCore import ChainRulesCore: NoTangent ChainRulesCore.rrule(::typeof(numargs), f) = (numargs(f), df -> (NoTangent(), NoTangent())) +ChainRulesCore.@non_differentiable checkkwargs(kwargshandle) + +function ChainRulesCore.frule(::typeof(solve_up), prob, + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; + kwargs...) + _solve_forward(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...; + kwargs...) +end + +function ChainRulesCore.rrule(::typeof(solve_up), prob::AbstractDEProblem, + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; + kwargs...) + _solve_adjoint(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...; + kwargs...) +end end \ No newline at end of file diff --git a/src/solve.jl b/src/solve.jl index d44f1a524..d337f55cc 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1051,8 +1051,6 @@ function checkkwargs(kwargshandle; kwargs...) end end -@non_differentiable checkkwargs(kwargshandle) - function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...) prob end @@ -1396,22 +1394,6 @@ discrete sensitivity algorithms. """ struct SensitivityADPassThrough <: AbstractDEAlgorithm end -function ChainRulesCore.frule(::typeof(solve_up), prob, - sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, - u0, p, args...; - kwargs...) - _solve_forward(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...; - kwargs...) -end - -function ChainRulesCore.rrule(::typeof(solve_up), prob::AbstractDEProblem, - sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, - u0, p, args...; - kwargs...) - _solve_adjoint(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...; - kwargs...) -end - ### ### Legacy Dispatches to be Non-Breaking ###