diff --git a/Project.toml b/Project.toml index c21296c70..f2bc69504 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas "] version = "6.150.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -60,6 +61,7 @@ DiffEqBaseTrackerExt = "Tracker" DiffEqBaseUnitfulExt = "Unitful" [compat] +Adapt = "4" ArrayInterface = "7.8" ChainRulesCore = "1" ConcreteStructs = "0.2.3" diff --git a/ext/DiffEqBaseTrackerExt.jl b/ext/DiffEqBaseTrackerExt.jl index b7a67c815..54e6e9f85 100644 --- a/ext/DiffEqBaseTrackerExt.jl +++ b/ext/DiffEqBaseTrackerExt.jl @@ -4,10 +4,12 @@ if isdefined(Base, :get_extension) using DiffEqBase import DiffEqBase: value import Tracker + import Adapt else using ..DiffEqBase import ..DiffEqBase: value import ..Tracker + import ..Adapt end DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T @@ -97,9 +99,10 @@ Tracker.@grad function DiffEqBase.solve_up(prob, }, u0, p, args...; kwargs...) - out = DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p), + sol, pb_f = DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p), SciMLBase.TrackerOriginator(), args...; kwargs...) - Array(out[1]), out[2] + sol isa AbstractArray && return sol.u, pb_f # AbstractNoTimeSolution isa AbstractArray + return convert(AbstractArray, sol), pb_f end end diff --git a/src/solve.jl b/src/solve.jl index 989870c10..c5cc0a979 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1130,23 +1130,23 @@ function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...) end function get_concrete_problem(prob::SteadyStateProblem, isadapt; kwargs...) - u0 = get_concrete_u0(prob, isadapt, Inf, kwargs) - u0 = promote_u0(u0, prob.p, nothing) p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, Inf, kwargs) + u0 = promote_u0(u0, p, nothing) remake(prob; u0 = u0, p = p) end function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) - u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) - u0 = promote_u0(u0, prob.p, nothing) p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) remake(prob; u0 = u0, p = p) end function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) - u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) - u0 = promote_u0(u0, prob.p, nothing) p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) remake(prob; u0 = u0, p = p) end