Skip to content

Commit

Permalink
try to handle more cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Pepijn de Vos committed Nov 13, 2023
1 parent b3e6ae3 commit cb00e09
Showing 1 changed file with 46 additions and 12 deletions.
58 changes: 46 additions & 12 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -616,15 +616,15 @@ end
##################### Hermite Interpolants

# If no dispatch found, assume Hermite
function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI}
hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI}
partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI}
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI}
partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
end

function _ode_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractArray, T::Type{Val{TI}}, differential_vars) where {TI}
function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractVector{Int}, T::Type{Val{TI}}, differential_vars) where {TI}
sel = differential_vars[idxs]
diffidxs = idxs[sel]
linidxs = idxs[.!sel]
Expand All @@ -636,17 +636,34 @@ function _ode_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, c
res
end

function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI}
_ode_interpolant(Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars)
function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI}
partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars)
end

function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Int, T::Type{Val{TI}}, differential_vars) where {TI}
if differential_vars[idxs]
hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
else
linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
end
end

function partial_hermite_interpolant(Θ, dt, y₀::Number, y₁::Number, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
if only(differential_vars)
hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
else
linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
end
end

function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
function partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
throw("how did we get here")
@assert all(differential_vars)
@show differential_vars
hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
end

function _ode_interpolant!(out::AbstractArray, Θ, dt, y₀, y₁, k, cache, idxs::AbstractArray, T::Type{Val{TI}}, differential_vars) where {TI}
function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractVector{Int}, T::Type{Val{TI}}, differential_vars) where {TI}
sel = differential_vars[idxs]
diffidxs = idxs[sel]
linidxs = idxs[.!sel]
Expand All @@ -655,11 +672,28 @@ function _ode_interpolant!(out::AbstractArray, Θ, dt, y₀, y₁, k, cache, idx
out
end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI}
_ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars)
function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI}
partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars)
end

function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Int, T::Type{Val{TI}}, differential_vars) where {TI}
if differential_vars[idxs]
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
else
linear_interpolant!(out, Θ, dt, y₀, y₁, idxs, T)
end
end

function partial_hermite_interpolant!(out, Θ, dt, y₀::Number, y₁::Number, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
if only(differential_vars)
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
else
linear_interpolant!(out, Θ, dt, y₀, y₁, idxs, T)
end
end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
function partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
throw("how did we get here")
@assert all(differential_vars)
@show differential_vars
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
Expand Down

0 comments on commit cb00e09

Please sign in to comment.