diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index c83af3b492..92e225f1f9 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -315,6 +315,22 @@ function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs, end end +function get_differential_vars(f, len) + differential_vars = trues(len) + if hasproperty(f, :mass_matrix) + mm = f.mass_matrix + if mm isa UniformScaling + # already correct + elseif isdiag(mm) + differential_vars = Diagonal(mm).diag .!= 0 + else + @show typeof(mm) + error("QR factorizations is annoying") + end + end + @show differential_vars +end + """ ode_interpolation(tvals,ts,timeseries,ks) @@ -326,6 +342,7 @@ function ode_interpolation(tvals, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) idx = sortperm(tvals, rev = tdir < 0) + differential_vars = get_differential_vars(f, length(timeseries[begin])) # start the search thinking it's ts[1]-ts[2] i₋₊ref = Ref((1, 2)) vals = map(idx) do j @@ -363,6 +380,7 @@ function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) idx = sortperm(tvals, rev = tdir < 0) + differential_vars = get_differential_vars(f, length(timeseries[begin])) # start the search thinking it's in ts[1]-ts[2] i₋ = 1 @@ -448,20 +466,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) - differential_vars = trues(length(timeseries[begin])) - if hasproperty(f, :mass_matrix) - mm = f.mass_matrix - if mm isa UniformScaling - # already correct - elseif isdiag(mm) - differential_vars = Diagonal(mm).diag .!= 0 - else - @show typeof(mm) - error("QR factorizations is annoying") - end - end - @show differential_vars - + differential_vars = get_differential_vars(f, length(timeseries[begin])) if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], @@ -511,20 +516,7 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) - differential_vars = trues(length(timeseries[begin])) - if hasproperty(f, :mass_matrix) - mm = f.mass_matrix - if mm isa UniformScaling - # already correct - elseif isdiag(mm) - differential_vars = Diagonal(mm).diag .!= 0 - else - # @show typeof(mm) - error("QR factorizations is annoying") - end - end - @show "!" differential_vars - + differential_vars = get_differential_vars(f, length(timeseries[begin])) if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], @@ -634,24 +626,42 @@ end function _ode_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractArray, T::Type{Val{TI}}, differential_vars) where {TI} sel = differential_vars[idxs] - @show y₀ idxs differential_vars sel diffidxs = idxs[sel] linidxs = idxs[.!sel] diff = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, diffidxs, T) - @show lin = linear_interpolant(Θ, dt, y₀, y₁, linidxs, T) - @show res = similar(y₀, eltype(y₀), size(sel)) - res[diffidxs] = diff - res[linidxs] = lin + lin = linear_interpolant(Θ, dt, y₀, y₁, linidxs, T) + res = similar(y₀, eltype(y₀), size(sel)) + res[sel] = diff + res[.!sel] = lin res end +function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI} + _ode_interpolant(Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars) +end + function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI} @assert all(differential_vars) @show differential_vars hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) end +function _ode_interpolant!(out::AbstractArray, Θ, dt, y₀, y₁, k, cache, idxs::AbstractArray, T::Type{Val{TI}}, differential_vars) where {TI} + sel = differential_vars[idxs] + diffidxs = idxs[sel] + linidxs = idxs[.!sel] + hermite_interpolant!(@view(out[sel]), Θ, dt, y₀, y₁, k, diffidxs, T) + linear_interpolant!(@view(out[.!sel]), Θ, dt, y₀, y₁, linidxs, T) + out +end + +function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI} + _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars) +end + function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI} + @assert all(differential_vars) + @show differential_vars hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T) end