Skip to content

Commit

Permalink
Move differential_vars calculation to the integrator
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Dec 10, 2023
1 parent 9ab430e commit 174a6a2
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 57 deletions.
68 changes: 23 additions & 45 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ end
DiffEqBase.addsteps!(integrator)
if !(integrator.cache isa CompositeCache)
val = ode_interpolant(Θ, integrator.dt, integrator.uprev, integrator.u,
integrator.k, integrator.cache, idxs, deriv)
integrator.k, integrator.cache, idxs, deriv, integrator.differential_vars)
else
val = composite_ode_interpolant(Θ, integrator, integrator.cache.caches,
integrator.cache.current, idxs, deriv)
Expand Down Expand Up @@ -122,11 +122,11 @@ end
DiffEqBase.addsteps!(integrator)
if !(integrator.cache isa CompositeCache)
ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u,
integrator.k, integrator.cache, idxs, deriv)
integrator.k, integrator.cache, idxs, deriv, integrator.differential_vars)
else
ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u,
integrator.k, integrator.cache.caches[integrator.cache.current],
idxs, deriv)
idxs, deriv, integrator.differential_vars)
end
end

Expand Down Expand Up @@ -210,7 +210,7 @@ end
DiffEqBase.addsteps!(integrator)
if !(integrator.cache isa CompositeCache)
ode_interpolant!(val, Θ, integrator.t - integrator.tprev, integrator.uprev2,
integrator.uprev, integrator.k, integrator.cache, idxs, deriv)
integrator.uprev, integrator.k, integrator.cache, idxs, deriv, integrator.differential_vars)
else
composite_ode_extrapolant!(val, Θ, integrator, integrator.cache.caches,
integrator.cache.current, idxs, deriv)
Expand Down Expand Up @@ -241,7 +241,7 @@ end
DiffEqBase.addsteps!(integrator)
if !(integrator.cache isa CompositeCache)
ode_interpolant(Θ, integrator.t - integrator.tprev, integrator.uprev2,
integrator.uprev, integrator.k, integrator.cache, idxs, deriv)
integrator.uprev, integrator.k, integrator.cache, idxs, deriv, integrator.differential_vars)
else
composite_ode_extrapolant(Θ, integrator, integrator.cache.caches,
integrator.cache.current, idxs, deriv)
Expand Down Expand Up @@ -278,61 +278,41 @@ function _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
end
function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊,
caches::Tuple{C1, C2, Vararg}, idxs,
deriv, ks, ts, p, cacheid) where {C1, C2}
deriv, ks, ts, p, cacheid, differential_vars) where {C1, C2}
if (cacheid -= 1) != 0
return evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, Base.tail(caches),
idxs,
deriv, ks, ts, p, cacheid)
deriv, ks, ts, p, cacheid, differential_vars)
end
_evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
first(caches), idxs,
deriv, ks, ts, p)
deriv, ks, ts, p, differential_vars)
end
function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊,
caches::Tuple{C}, idxs,
deriv, ks, ts, p, _) where {C}
deriv, ks, ts, p, _, differential_vars) where {C}
_evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
only(caches), idxs,
deriv, ks, ts, p)
deriv, ks, ts, p, differential_vars)
end

function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs,
deriv, ks, ts, id, p, differential_vars)
if cache isa (FunctionMapCache) || cache isa FunctionMapConstantCache
return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs,
deriv)
deriv, differential_vars)
elseif !id.dense
return linear_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], idxs, deriv)
elseif cache isa CompositeCache
return evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, cache.caches, idxs,
deriv, ks, ts, p, id.alg_choice[i₊])
deriv, ks, ts, p, id.alg_choice[i₊], differential_vars)
else
return _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
cache, idxs,
deriv, ks, ts, p, differential_vars)
end
end

struct DifferentialVarsUndefined end
function get_differential_vars(f, idxs, timeseries)
differential_vars = nothing
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if mm isa UniformScaling
return nothing
elseif isdiag(mm) && all(x -> size(x) == size(timeseries[begin]), timeseries)
differential_vars = reshape(diag(mm) .!= 0, size(timeseries[begin]))
else
return DifferentialVarsUndefined()
end
end
if idxs === nothing
return differential_vars
else
return @view differential_vars[idxs]
end
end

"""
ode_interpolation(tvals,ts,timeseries,ks)
Expand All @@ -341,10 +321,9 @@ times ts (sorted), with values timeseries and derivatives ks
"""
function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
continuity::Symbol = :left) where {I, D}
@unpack ts, timeseries, ks, f, cache = id
@unpack ts, timeseries, ks, f, cache, differential_vars = id
@inbounds tdir = sign(ts[end] - ts[1])
idx = sortperm(tvals, rev = tdir < 0)
differential_vars = get_differential_vars(f, idxs, timeseries)
# start the search thinking it's ts[1]-ts[2]
i₋₊ref = Ref((1, 2))
vals = map(idx) do j
Expand Down Expand Up @@ -379,10 +358,9 @@ times ts (sorted), with values timeseries and derivatives ks
"""
function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p,
continuity::Symbol = :left) where {I, D}
@unpack ts, timeseries, ks, f, cache = id
@unpack ts, timeseries, ks, f, cache, differential_vars = id
@inbounds tdir = sign(ts[end] - ts[1])
idx = sortperm(tvals, rev = tdir < 0)
differential_vars = get_differential_vars(f, idxs, timeseries)

# start the search thinking it's in ts[1]-ts[2]
i₋ = 1
Expand Down Expand Up @@ -465,11 +443,9 @@ times ts (sorted), with values timeseries and derivatives ks
"""
function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
continuity::Symbol = :left) where {I, D}
@unpack ts, timeseries, ks, f, cache = id
@unpack ts, timeseries, ks, f, cache, differential_vars = id
@inbounds tdir = sign(ts[end] - ts[1])

differential_vars = get_differential_vars(f, idxs, timeseries)

if continuity === :left
# we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end],
# and otherwise i₋ and i₊ satisfy ts[i₋] < tval ≤ ts[i₊]
Expand All @@ -488,7 +464,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,

if cache isa (FunctionMapCache) || cache isa FunctionMapConstantCache
val = ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs,
deriv)
deriv, differential_vars)
elseif !id.dense
val = linear_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], idxs, deriv)
elseif cache isa CompositeCache
Expand All @@ -515,11 +491,9 @@ times ts (sorted), with values timeseries and derivatives ks
"""
function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
continuity::Symbol = :left) where {I, D}
@unpack ts, timeseries, ks, f, cache = id
@unpack ts, timeseries, ks, f, cache, differential_vars = id
@inbounds tdir = sign(ts[end] - ts[1])

differential_vars = get_differential_vars(f, idxs, timeseries)

if continuity === :left
# we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end],
# and otherwise i₋ and i₊ satisfy ts[i₋] < tval ≤ ts[i₊]
Expand All @@ -538,14 +512,14 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,

if cache isa (FunctionMapCache) || cache isa FunctionMapConstantCache
ode_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs,
deriv)
deriv, differential_vars)
elseif !id.dense
linear_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], idxs, deriv)
elseif cache isa CompositeCache
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
cache.caches[id.alg_choice[i₊]]) # update the kcurrent
ode_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
cache.caches[id.alg_choice[i₊]], idxs, deriv)
cache.caches[id.alg_choice[i₊]], idxs, deriv, differential_vars)
else
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
cache) # update the kcurrent
Expand Down Expand Up @@ -646,6 +620,8 @@ function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}},
else
differential_vars = Trues(size(idxs))
end
elseif idxs !== nothing
@view differential_vars[idxs]
end

hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T, differential_vars)
Expand All @@ -662,6 +638,8 @@ function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{
else
differential_vars = Trues(size(idxs))
end
elseif idxs !== nothing
@view differential_vars[idxs]
end

hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T, differential_vars)
Expand Down
13 changes: 7 additions & 6 deletions src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ For more info see the linked documentation page.
mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, IIP,
uType, duType, tType, pType, eigenType, EEstT, QT, tdirType,
ksEltype, SolType, F, CacheType, O, FSALType, EventErrorType,
CallbackCacheType, IA} <:
CallbackCacheType, IA, DV} <:
DiffEqBase.AbstractODEIntegrator{algType, IIP, uType, tType}
sol::SolType
u::uType
Expand Down Expand Up @@ -133,13 +133,14 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori
opts::O
stats::DiffEqBase.Stats
initializealg::IA
differential_vars::DV
fsalfirst::FSALType
fsallast::FSALType

function ODEIntegrator{algType, IIP, uType, duType, tType, pType, eigenType, EEstT,
tTypeNoUnits, tdirType, ksEltype, SolType,
F, CacheType, O, FSALType, EventErrorType, CallbackCacheType,
InitializeAlgType}(sol, u, du, k, t, dt, f, p, uprev, uprev2,
InitializeAlgType, DV}(sol, u, du, k, t, dt, f, p, uprev, uprev2,
duprev, tprev,
alg, dtcache, dtchangeable, dtpropose, tdir,
eigen_est, EEst, qold, q11, erracc, dtacc,
Expand All @@ -154,7 +155,7 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori
accept_step, isout, reeval_fsal, u_modified,
reinitialize, isdae,
opts, stats,
initializealg) where {algType, IIP, uType,
initializealg, differential_vars) where {algType, IIP, uType,
duType, tType, pType,
eigenType, EEstT,
tTypeNoUnits, tdirType,
Expand All @@ -163,10 +164,10 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori
FSALType,
EventErrorType,
CallbackCacheType,
InitializeAlgType}
InitializeAlgType, DV}
new{algType, IIP, uType, duType, tType, pType, eigenType, EEstT, tTypeNoUnits,
tdirType, ksEltype, SolType,
F, CacheType, O, FSALType, EventErrorType, CallbackCacheType, InitializeAlgType,
F, CacheType, O, FSALType, EventErrorType, CallbackCacheType, InitializeAlgType, DV
}(sol, u, du, k, t, dt, f, p, uprev, uprev2, duprev, tprev,
alg, dtcache, dtchangeable, dtpropose, tdir,
eigen_est, EEst, qold, q11, erracc, dtacc, success_iter,
Expand All @@ -175,7 +176,7 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori
do_error_check,
event_last_time, vector_event_last_time, last_event_error,
accept_step, isout, reeval_fsal, u_modified, reinitialize, isdae,
opts, stats, initializealg) # Leave off fsalfirst and last
opts, stats, initializealg, differential_vars) # Leave off fsalfirst and last
end
end

Expand Down
6 changes: 4 additions & 2 deletions src/interp_func.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
abstract type OrdinaryDiffEqInterpolation{cacheType} <:
DiffEqBase.AbstractDiffEqInterpolation end

struct InterpolationData{F, uType, tType, kType, cacheType} <:
struct InterpolationData{F, uType, tType, kType, cacheType, DV} <:
OrdinaryDiffEqInterpolation{cacheType}
f::F
timeseries::uType
ts::tType
ks::kType
dense::Bool
cache::cacheType
differential_vars::DV
end

struct CompositeInterpolationData{F, uType, tType, kType, cacheType} <:
struct CompositeInterpolationData{F, uType, tType, kType, cacheType, DV} <:
OrdinaryDiffEqInterpolation{cacheType}
f::F
timeseries::uType
Expand All @@ -20,6 +21,7 @@ struct CompositeInterpolationData{F, uType, tType, kType, cacheType} <:
alg_choice::Vector{Int}
dense::Bool
cache::cacheType
differential_vars::DV
end

function DiffEqBase.interp_summary(interp::OrdinaryDiffEqInterpolation{
Expand Down
24 changes: 24 additions & 0 deletions src/misc_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,27 @@ macro fold(arg)
esc(:(Base.@assume_effects :foldable $arg))
end
end

struct DifferentialVarsUndefined end

"""
get_differential_vars(f, idxs, timeseries::uType)
Returns an array of booleans for which values are the differential variables
vs algebraic variables. Returns `nothing` for the cases where all variables
are differential variables. Returns `DifferentialVarsUndefined` if it cannot
be determined (i.e. the mass matrix is not diagonal).
"""
function get_differential_vars(f, u)
differential_vars = nothing
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if mm isa UniformScaling
return nothing
elseif isdiag(mm)
differential_vars = reshape(diag(mm) .!= 0, size(u))
else
return DifferentialVarsUndefined()
end
end
end
9 changes: 5 additions & 4 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,15 +404,16 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,
stop_at_next_tstop)

stats = DiffEqBase.Stats(0)
differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u)

if _alg isa OrdinaryDiffEqCompositeAlgorithm
id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache)
id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars)
sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries,
dense = dense, k = ks, interp = id,
alg_choice = alg_choice,
calculate_error = false, stats = stats)
else
id = InterpolationData(f, timeseries, ts, ks, dense, cache)
id = InterpolationData(f, timeseries, ts, ks, dense, cache, differential_vars)
sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries,
dense = dense, k = ks, interp = id,
calculate_error = false, stats = stats)
Expand Down Expand Up @@ -469,7 +470,7 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,
FType, cacheType,
typeof(opts), fsal_typeof(_alg, rate_prototype),
typeof(last_event_error), typeof(callback_cache),
typeof(initializealg)}(sol, u, du, k, t, tType(dt), f, p,
typeof(initializealg), typeof(differential_vars)}(sol, u, du, k, t, tType(dt), f, p,
uprev, uprev2, duprev, tprev,
_alg, dtcache, dtchangeable,
dtpropose, tdir, eigen_est, EEst,
Expand All @@ -485,7 +486,7 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,
last_event_error, accept_step,
isout, reeval_fsal,
u_modified, reinitiailize, isdae,
opts, stats, initializealg)
opts, stats, initializealg, differnetial_vars)

if initialize_integrator
if isdae
Expand Down

0 comments on commit 174a6a2

Please sign in to comment.