diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index 7c4d38c3e4..c83af3b492 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -624,11 +624,34 @@ end ##################### Hermite Interpolants # If no dispatch found, assume Hermite -function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI} +function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI} hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) end -function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI} +function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI} + hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T) +end + +function _ode_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractArray, T::Type{Val{TI}}, differential_vars) where {TI} + sel = differential_vars[idxs] + @show y₀ idxs differential_vars sel + diffidxs = idxs[sel] + linidxs = idxs[.!sel] + diff = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, diffidxs, T) + @show lin = linear_interpolant(Θ, dt, y₀, y₁, linidxs, T) + @show res = similar(y₀, eltype(y₀), size(sel)) + res[diffidxs] = diff + res[linidxs] = lin + res +end + +function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI} + @assert all(differential_vars) + @show differential_vars + hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) +end + +function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI} hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T) end