Skip to content

Commit

Permalink
Merge pull request #593 from SciML/sensitivity_interpolation
Browse files Browse the repository at this point in the history
Make sensitivity interpolations type preserving
  • Loading branch information
ChrisRackauckas authored Jan 7, 2024
2 parents 37907fd + 924b1d7 commit 9e06a79
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 27 deletions.
57 changes: 37 additions & 20 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
function enable_interpolation_sensitivitymode end

enable_interpolation_sensitivitymode(interp::Nothing) = nothing

# Pass through should be deprecated in the future, made for backwards compat
enable_interpolation_sensitivitymode(interp::AbstractDiffEqInterpolation) = interp
struct SensitivityInterpolation end

"""
$(TYPEDEF)
"""
struct HermiteInterpolation{T1, T2, T3} <: AbstractDiffEqInterpolation
t::T1
u::T2
du::T3
sensitivitymode::Bool
end

function HermiteInterpolation(t,u,du; sensitivitymode=false)
HermiteInterpolation(t,u,du,sensitivitymode)
end

function enable_interpolation_sensitivitymode(interp::HermiteInterpolation)
HermiteInterpolation(interp.t,interp.u,interp.du,true)
end

"""
Expand All @@ -13,6 +30,15 @@ $(TYPEDEF)
struct LinearInterpolation{T1, T2} <: AbstractDiffEqInterpolation
t::T1
u::T2
sensitivitymode::Bool
end

function LinearInterpolation(t,u; sensitivitymode=false)
LinearInterpolation(t,u,sensitivitymode)
end

function enable_interpolation_sensitivitymode(interp::LinearInterpolation)
LinearInterpolation(interp.t,interp.u,true)
end

"""
Expand All @@ -21,24 +47,22 @@ $(TYPEDEF)
struct ConstantInterpolation{T1, T2} <: AbstractDiffEqInterpolation
t::T1
u::T2
sensitivitymode::Bool
end

"""
$(TYPEDEF)
"""
struct SensitivityInterpolation{T1, T2} <: AbstractDiffEqInterpolation
t::T1
u::T2
function ConstantInterpolation(t,u; sensitivitymode=false)
ConstantInterpolation(t,u,sensitivitymode)
end

function enable_interpolation_sensitivitymode(interp::ConstantInterpolation)
ConstantInterpolation(interp.t,interp.u,true)
end

interp_summary(::AbstractDiffEqInterpolation) = "Unknown"
interp_summary(::HermiteInterpolation) = "3rd order Hermite"
interp_summary(::LinearInterpolation) = "1st order linear"
interp_summary(::ConstantInterpolation) = "Piecewise constant interpolation"
interp_summary(::Nothing) = "No interpolation"
function interp_summary(::SensitivityInterpolation)
"Interpolation disabled due to sensitivity analysis"
end
interp_summary(sol::AbstractSciMLSolution) = interp_summary(sol.interp)

const SENSITIVITY_INTERP_MESSAGE = """
Expand Down Expand Up @@ -69,13 +93,6 @@ end
function (id::ConstantInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol = :left)
interpolation!(val, tvals, id, idxs, deriv, p, continuity)
end
function (id::SensitivityInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left)
interpolation(tvals, id, idxs, deriv, p, continuity)
end
function (id::SensitivityInterpolation)(val, tvals, idxs, deriv, p,
continuity::Symbol = :left)
interpolation!(val, tvals, id, idxs, deriv, p, continuity)
end

@inline function interpolation(tvals, id::I, idxs, deriv::D, p,
continuity::Symbol = :left) where {I, D}
Expand Down Expand Up @@ -118,7 +135,7 @@ end
vals[j] = u[k][idxs]
end
else
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
Expand Down Expand Up @@ -173,7 +190,7 @@ times t (sorted), with values u and derivatives ks
vals[j] = u[k][idxs]
end
else
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
Expand Down Expand Up @@ -232,7 +249,7 @@ times t (sorted), with values u and derivatives ks
val = u[i - 1][idxs]
end
else
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
Expand Down Expand Up @@ -282,7 +299,7 @@ times t (sorted), with values u and derivatives ks
copy!(out, u[i - 1][idxs])
end
else
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id.sensitivitymode && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
Expand Down
8 changes: 1 addition & 7 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,13 +396,7 @@ function sensitivity_solution(sol::ODESolution, u, t)
N = length((size(u0)..., length(u)))
end

interp = if sol.interp isa LinearInterpolation
LinearInterpolation(t, u)
elseif sol.interp isa ConstantInterpolation
ConstantInterpolation(t, u)
else
SensitivityInterpolation(t, u)
end
interp = enable_interpolation_sensitivitymode(sol.interp)

ODESolution{T, N}(u, sol.u_analytic, sol.errors, t,
nothing, sol.prob,
Expand Down

0 comments on commit 9e06a79

Please sign in to comment.