diff --git a/src/interpolation.jl b/src/interpolation.jl index d212ac82f..4b17baefb 100644 --- a/src/interpolation.jl +++ b/src/interpolation.jl @@ -1,3 +1,11 @@ +function enable_interpolation_sensitivitymode end + +enable_interpolation_sensitivitymode(interp::Nothing) = nothing + +# Pass through should be deprecated in the future, made for backwards compat +enable_interpolation_sensitivitymode(interp::AbstractDiffEqInterpolation) = interp +struct SensitivityInterpolation end + """ $(TYPEDEF) """ @@ -5,6 +13,15 @@ struct HermiteInterpolation{T1, T2, T3} <: AbstractDiffEqInterpolation t::T1 u::T2 du::T3 + sensitivitymode::Bool +end + +function HermiteInterpolation(t,u,du; sensitivitymode=false) + HermiteInterpolation(t,u,du,sensitivitymode) +end + +function enable_interpolation_sensitivitymode(interp::HermiteInterpolation) + HermiteInterpolation(interp.t,interp.u,interp.du,true) end """ @@ -13,6 +30,15 @@ $(TYPEDEF) struct LinearInterpolation{T1, T2} <: AbstractDiffEqInterpolation t::T1 u::T2 + sensitivitymode::Bool +end + +function LinearInterpolation(t,u; sensitivitymode=false) + LinearInterpolation(t,u,sensitivitymode) +end + +function enable_interpolation_sensitivitymode(interp::LinearInterpolation) + LinearInterpolation(interp.t,interp.u,true) end """ @@ -21,14 +47,15 @@ $(TYPEDEF) struct ConstantInterpolation{T1, T2} <: AbstractDiffEqInterpolation t::T1 u::T2 + sensitivitymode::Bool end -""" -$(TYPEDEF) -""" -struct SensitivityInterpolation{T1, T2} <: AbstractDiffEqInterpolation - t::T1 - u::T2 +function ConstantInterpolation(t,u; sensitivitymode=false) + ConstantInterpolation(t,u,sensitivitymode) +end + +function enable_interpolation_sensitivitymode(interp::ConstantInterpolation) + ConstantInterpolation(interp.t,interp.u,true) end interp_summary(::AbstractDiffEqInterpolation) = "Unknown" @@ -36,9 +63,6 @@ interp_summary(::HermiteInterpolation) = "3rd order Hermite" interp_summary(::LinearInterpolation) = "1st order linear" interp_summary(::ConstantInterpolation) = "Piecewise constant interpolation" interp_summary(::Nothing) = "No interpolation" -function interp_summary(::SensitivityInterpolation) - "Interpolation disabled due to sensitivity analysis" -end interp_summary(sol::AbstractSciMLSolution) = interp_summary(sol.interp) const SENSITIVITY_INTERP_MESSAGE = """ @@ -69,13 +93,6 @@ end function (id::ConstantInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol = :left) interpolation!(val, tvals, id, idxs, deriv, p, continuity) end -function (id::SensitivityInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left) - interpolation(tvals, id, idxs, deriv, p, continuity) -end -function (id::SensitivityInterpolation)(val, tvals, idxs, deriv, p, - continuity::Symbol = :left) - interpolation!(val, tvals, id, idxs, deriv, p, continuity) -end @inline function interpolation(tvals, id::I, idxs, deriv::D, p, continuity::Symbol = :left) where {I, D} @@ -118,7 +135,7 @@ end vals[j] = u[k][idxs] end else - id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) + id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] Θ = (tval - t[i - 1]) / dt idxs_internal = idxs @@ -173,7 +190,7 @@ times t (sorted), with values u and derivatives ks vals[j] = u[k][idxs] end else - id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) + id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] Θ = (tval - t[i - 1]) / dt idxs_internal = idxs @@ -232,7 +249,7 @@ times t (sorted), with values u and derivatives ks val = u[i - 1][idxs] end else - id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) + id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] Θ = (tval - t[i - 1]) / dt idxs_internal = idxs @@ -282,7 +299,7 @@ times t (sorted), with values u and derivatives ks copy!(out, u[i - 1][idxs]) end else - id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) + id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] Θ = (tval - t[i - 1]) / dt idxs_internal = idxs diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index d87375400..c95d2a0cb 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -396,13 +396,7 @@ function sensitivity_solution(sol::ODESolution, u, t) N = length((size(u0)..., length(u))) end - interp = if sol.interp isa LinearInterpolation - LinearInterpolation(t, u) - elseif sol.interp isa ConstantInterpolation - ConstantInterpolation(t, u) - else - SensitivityInterpolation(t, u) - end + interp = enable_interpolation_sensitivitymode(sol.interp) ODESolution{T, N}(u, sol.u_analytic, sol.errors, t, nothing, sol.prob,