diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index 0dc062bc7d..e7c34df715 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -90,7 +90,7 @@ end DiffEqBase.addsteps!(integrator) if !(integrator.cache isa CompositeCache) val = ode_interpolant(Θ, integrator.dt, integrator.uprev, integrator.u, - integrator.k, integrator.cache, idxs, deriv) + integrator.k, integrator.cache, idxs, deriv, integrator.differential_vars) else val = composite_ode_interpolant(Θ, integrator, integrator.cache.caches, integrator.cache.current, idxs, deriv) @@ -122,11 +122,11 @@ end DiffEqBase.addsteps!(integrator) if !(integrator.cache isa CompositeCache) ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u, - integrator.k, integrator.cache, idxs, deriv) + integrator.k, integrator.cache, idxs, deriv, integrator.differential_vars) else ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u, integrator.k, integrator.cache.caches[integrator.cache.current], - idxs, deriv) + idxs, deriv, integrator.differential_vars) end end @@ -210,7 +210,7 @@ end DiffEqBase.addsteps!(integrator) if !(integrator.cache isa CompositeCache) ode_interpolant!(val, Θ, integrator.t - integrator.tprev, integrator.uprev2, - integrator.uprev, integrator.k, integrator.cache, idxs, deriv) + integrator.uprev, integrator.k, integrator.cache, idxs, deriv, integrator.differential_vars) else composite_ode_extrapolant!(val, Θ, integrator, integrator.cache.caches, integrator.cache.current, idxs, deriv) @@ -241,7 +241,7 @@ end DiffEqBase.addsteps!(integrator) if !(integrator.cache isa CompositeCache) ode_interpolant(Θ, integrator.t - integrator.tprev, integrator.uprev2, - integrator.uprev, integrator.k, integrator.cache, idxs, deriv) + integrator.uprev, integrator.k, integrator.cache, idxs, deriv, integrator.differential_vars) else composite_ode_extrapolant(Θ, integrator, integrator.cache.caches, integrator.cache.current, idxs, deriv) @@ -278,34 +278,34 @@ function _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, end function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, caches::Tuple{C1, C2, Vararg}, idxs, - deriv, ks, ts, p, cacheid) where {C1, C2} + deriv, ks, ts, p, cacheid, differential_vars) where {C1, C2} if (cacheid -= 1) != 0 return evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, Base.tail(caches), idxs, - deriv, ks, ts, p, cacheid) + deriv, ks, ts, p, cacheid, differential_vars) end _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, first(caches), idxs, - deriv, ks, ts, p) + deriv, ks, ts, p, differential_vars) end function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, caches::Tuple{C}, idxs, - deriv, ks, ts, p, _) where {C} + deriv, ks, ts, p, _, differential_vars) where {C} _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, only(caches), idxs, - deriv, ks, ts, p) + deriv, ks, ts, p, differential_vars) end function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs, deriv, ks, ts, id, p, differential_vars) if cache isa (FunctionMapCache) || cache isa FunctionMapConstantCache return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs, - deriv) + deriv, differential_vars) elseif !id.dense return linear_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], idxs, deriv) elseif cache isa CompositeCache return evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, cache.caches, idxs, - deriv, ks, ts, p, id.alg_choice[i₊]) + deriv, ks, ts, p, id.alg_choice[i₊], differential_vars) else return _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs, @@ -313,26 +313,6 @@ function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs, end end -struct DifferentialVarsUndefined end -function get_differential_vars(f, idxs, timeseries) - differential_vars = nothing - if hasproperty(f, :mass_matrix) - mm = f.mass_matrix - if mm isa UniformScaling - return nothing - elseif isdiag(mm) && all(x -> size(x) == size(timeseries[begin]), timeseries) - differential_vars = reshape(diag(mm) .!= 0, size(timeseries[begin])) - else - return DifferentialVarsUndefined() - end - end - if idxs === nothing - return differential_vars - else - return @view differential_vars[idxs] - end -end - """ ode_interpolation(tvals,ts,timeseries,ks) @@ -341,10 +321,9 @@ times ts (sorted), with values timeseries and derivatives ks """ function ode_interpolation(tvals, id::I, idxs, deriv::D, p, continuity::Symbol = :left) where {I, D} - @unpack ts, timeseries, ks, f, cache = id + @unpack ts, timeseries, ks, f, cache, differential_vars = id @inbounds tdir = sign(ts[end] - ts[1]) idx = sortperm(tvals, rev = tdir < 0) - differential_vars = get_differential_vars(f, idxs, timeseries) # start the search thinking it's ts[1]-ts[2] i₋₊ref = Ref((1, 2)) vals = map(idx) do j @@ -379,10 +358,9 @@ times ts (sorted), with values timeseries and derivatives ks """ function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p, continuity::Symbol = :left) where {I, D} - @unpack ts, timeseries, ks, f, cache = id + @unpack ts, timeseries, ks, f, cache, differential_vars = id @inbounds tdir = sign(ts[end] - ts[1]) idx = sortperm(tvals, rev = tdir < 0) - differential_vars = get_differential_vars(f, idxs, timeseries) # start the search thinking it's in ts[1]-ts[2] i₋ = 1 @@ -465,11 +443,9 @@ times ts (sorted), with values timeseries and derivatives ks """ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p, continuity::Symbol = :left) where {I, D} - @unpack ts, timeseries, ks, f, cache = id + @unpack ts, timeseries, ks, f, cache, differential_vars = id @inbounds tdir = sign(ts[end] - ts[1]) - differential_vars = get_differential_vars(f, idxs, timeseries) - if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], # and otherwise i₋ and i₊ satisfy ts[i₋] < tval ≤ ts[i₊] @@ -488,7 +464,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p, if cache isa (FunctionMapCache) || cache isa FunctionMapConstantCache val = ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs, - deriv) + deriv, differential_vars) elseif !id.dense val = linear_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], idxs, deriv) elseif cache isa CompositeCache @@ -515,11 +491,9 @@ times ts (sorted), with values timeseries and derivatives ks """ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p, continuity::Symbol = :left) where {I, D} - @unpack ts, timeseries, ks, f, cache = id + @unpack ts, timeseries, ks, f, cache, differential_vars = id @inbounds tdir = sign(ts[end] - ts[1]) - differential_vars = get_differential_vars(f, idxs, timeseries) - if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], # and otherwise i₋ and i₊ satisfy ts[i₋] < tval ≤ ts[i₊] @@ -538,14 +512,14 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p, if cache isa (FunctionMapCache) || cache isa FunctionMapConstantCache ode_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs, - deriv) + deriv, differential_vars) elseif !id.dense linear_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], idxs, deriv) elseif cache isa CompositeCache _ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p, cache.caches[id.alg_choice[i₊]]) # update the kcurrent ode_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊], - cache.caches[id.alg_choice[i₊]], idxs, deriv) + cache.caches[id.alg_choice[i₊]], idxs, deriv, differential_vars) else _ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p, cache) # update the kcurrent @@ -646,6 +620,8 @@ function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, else differential_vars = Trues(size(idxs)) end + elseif idxs !== nothing + @view differential_vars[idxs] end hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T, differential_vars) @@ -662,6 +638,8 @@ function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{ else differential_vars = Trues(size(idxs)) end + elseif idxs !== nothing + @view differential_vars[idxs] end hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T, differential_vars) diff --git a/src/integrators/type.jl b/src/integrators/type.jl index 6f96de274a..cdf3b077c1 100644 --- a/src/integrators/type.jl +++ b/src/integrators/type.jl @@ -85,7 +85,7 @@ For more info see the linked documentation page. mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, IIP, uType, duType, tType, pType, eigenType, EEstT, QT, tdirType, ksEltype, SolType, F, CacheType, O, FSALType, EventErrorType, - CallbackCacheType, IA} <: + CallbackCacheType, IA, DV} <: DiffEqBase.AbstractODEIntegrator{algType, IIP, uType, tType} sol::SolType u::uType @@ -133,13 +133,14 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori opts::O stats::DiffEqBase.Stats initializealg::IA + differential_vars::DV fsalfirst::FSALType fsallast::FSALType function ODEIntegrator{algType, IIP, uType, duType, tType, pType, eigenType, EEstT, tTypeNoUnits, tdirType, ksEltype, SolType, F, CacheType, O, FSALType, EventErrorType, CallbackCacheType, - InitializeAlgType}(sol, u, du, k, t, dt, f, p, uprev, uprev2, + InitializeAlgType, DV}(sol, u, du, k, t, dt, f, p, uprev, uprev2, duprev, tprev, alg, dtcache, dtchangeable, dtpropose, tdir, eigen_est, EEst, qold, q11, erracc, dtacc, @@ -154,7 +155,7 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori accept_step, isout, reeval_fsal, u_modified, reinitialize, isdae, opts, stats, - initializealg) where {algType, IIP, uType, + initializealg, differential_vars) where {algType, IIP, uType, duType, tType, pType, eigenType, EEstT, tTypeNoUnits, tdirType, @@ -163,10 +164,10 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori FSALType, EventErrorType, CallbackCacheType, - InitializeAlgType} + InitializeAlgType, DV} new{algType, IIP, uType, duType, tType, pType, eigenType, EEstT, tTypeNoUnits, tdirType, ksEltype, SolType, - F, CacheType, O, FSALType, EventErrorType, CallbackCacheType, InitializeAlgType, + F, CacheType, O, FSALType, EventErrorType, CallbackCacheType, InitializeAlgType, DV }(sol, u, du, k, t, dt, f, p, uprev, uprev2, duprev, tprev, alg, dtcache, dtchangeable, dtpropose, tdir, eigen_est, EEst, qold, q11, erracc, dtacc, success_iter, @@ -175,7 +176,7 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori do_error_check, event_last_time, vector_event_last_time, last_event_error, accept_step, isout, reeval_fsal, u_modified, reinitialize, isdae, - opts, stats, initializealg) # Leave off fsalfirst and last + opts, stats, initializealg, differential_vars) # Leave off fsalfirst and last end end diff --git a/src/interp_func.jl b/src/interp_func.jl index 0f43051a87..d6be19dddb 100644 --- a/src/interp_func.jl +++ b/src/interp_func.jl @@ -1,7 +1,7 @@ abstract type OrdinaryDiffEqInterpolation{cacheType} <: DiffEqBase.AbstractDiffEqInterpolation end -struct InterpolationData{F, uType, tType, kType, cacheType} <: +struct InterpolationData{F, uType, tType, kType, cacheType, DV} <: OrdinaryDiffEqInterpolation{cacheType} f::F timeseries::uType @@ -9,9 +9,10 @@ struct InterpolationData{F, uType, tType, kType, cacheType} <: ks::kType dense::Bool cache::cacheType + differential_vars::DV end -struct CompositeInterpolationData{F, uType, tType, kType, cacheType} <: +struct CompositeInterpolationData{F, uType, tType, kType, cacheType, DV} <: OrdinaryDiffEqInterpolation{cacheType} f::F timeseries::uType @@ -20,6 +21,7 @@ struct CompositeInterpolationData{F, uType, tType, kType, cacheType} <: alg_choice::Vector{Int} dense::Bool cache::cacheType + differential_vars::DV end function DiffEqBase.interp_summary(interp::OrdinaryDiffEqInterpolation{ diff --git a/src/misc_utils.jl b/src/misc_utils.jl index 02d3d7c626..066d1eb9e9 100644 --- a/src/misc_utils.jl +++ b/src/misc_utils.jl @@ -171,3 +171,27 @@ macro fold(arg) esc(:(Base.@assume_effects :foldable $arg)) end end + +struct DifferentialVarsUndefined end + +""" + get_differential_vars(f, idxs, timeseries::uType) + +Returns an array of booleans for which values are the differential variables +vs algebraic variables. Returns `nothing` for the cases where all variables +are differential variables. Returns `DifferentialVarsUndefined` if it cannot +be determined (i.e. the mass matrix is not diagonal). +""" +function get_differential_vars(f, u) + differential_vars = nothing + if hasproperty(f, :mass_matrix) + mm = f.mass_matrix + if mm isa UniformScaling + return nothing + elseif isdiag(mm) + differential_vars = reshape(diag(mm) .!= 0, size(u)) + else + return DifferentialVarsUndefined() + end + end +end \ No newline at end of file diff --git a/src/solve.jl b/src/solve.jl index 5d73383239..6ff6bde34e 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -404,15 +404,16 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem, stop_at_next_tstop) stats = DiffEqBase.Stats(0) + differential_vars = prob isa DAEProblem ? prob.differential_vars : get_differential_vars(f, u) if _alg isa OrdinaryDiffEqCompositeAlgorithm - id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache) + id = CompositeInterpolationData(f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars) sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, dense = dense, k = ks, interp = id, alg_choice = alg_choice, calculate_error = false, stats = stats) else - id = InterpolationData(f, timeseries, ts, ks, dense, cache) + id = InterpolationData(f, timeseries, ts, ks, dense, cache, differential_vars) sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries, dense = dense, k = ks, interp = id, calculate_error = false, stats = stats) @@ -469,7 +470,7 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem, FType, cacheType, typeof(opts), fsal_typeof(_alg, rate_prototype), typeof(last_event_error), typeof(callback_cache), - typeof(initializealg)}(sol, u, du, k, t, tType(dt), f, p, + typeof(initializealg), typeof(differential_vars)}(sol, u, du, k, t, tType(dt), f, p, uprev, uprev2, duprev, tprev, _alg, dtcache, dtchangeable, dtpropose, tdir, eigen_est, EEst, @@ -485,7 +486,7 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem, last_event_error, accept_step, isout, reeval_fsal, u_modified, reinitiailize, isdae, - opts, stats, initializealg) + opts, stats, initializealg, differnetial_vars) if initialize_integrator if isdae