Skip to content

Commit

Permalink
add differential_vars arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Pepijn de Vos committed Nov 9, 2023
1 parent bcd225b commit e7dbe93
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 201 deletions.
23 changes: 10 additions & 13 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -511,19 +511,19 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])

diferential_vars = trues(length(timeseries[begin]))
differential_vars = trues(length(timeseries[begin]))
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if mm isa UniformScaling
# already correct
elseif isdiag(mm)
diferential_vars = Diagonal(mm).diag .!= 0
differential_vars = Diagonal(mm).diag .!= 0
else
# @show typeof(mm)
error("QR factorizations is annoying")
end
end
@show "!" diferential_vars
@show "!" differential_vars


if continuity === :left
Expand Down Expand Up @@ -586,15 +586,12 @@ end
"""
ode_interpolant and ode_interpolant! dispatch
"""
function ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI}
_ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T)
function ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI}
_ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
end

struct AllDifferential end
Base.getindex(::AllDifferential, ::Any) = true

function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCache, idxs,
T::Type{Val{TI}}, differential_vars=AllDifferential()) where {TI}
T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI}
@show differential_vars
if idxs isa Number || y₀ isa Union{Number, SArray}
# typeof(y₀) can be these if saveidxs gives a single value
Expand All @@ -620,18 +617,18 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCach
end
end

function ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI}
_ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T)
function ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI}
_ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
end

##################### Hermite Interpolants

# If no dispatch found, assume Hermite
function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=AllDifferential()) where {TI}
function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI}
hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=AllDifferential()) where {TI}
function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI}
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
end

Expand Down
Loading

0 comments on commit e7dbe93

Please sign in to comment.