Skip to content

Commit

Permalink
fix linear method tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Pepijn de Vos committed Nov 14, 2023
1 parent 824e33b commit 191447a
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -621,21 +621,21 @@ function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{
partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
end

function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractVector{Int}, T::Type{Val{TI}}, differential_vars) where {TI}
function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractArray{Int}, T::Type{Val{TI}}, differential_vars) where {TI}
sel = differential_vars[idxs]
diffidxs = idxs[sel]
linidxs = idxs[.!sel]
diff = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, diffidxs, T)
lin = linear_interpolant(Θ, dt, y₀, y₁, linidxs, T)
res = similar(y₀, eltype(y₀), size(sel))
res[sel] = diff
res[.!sel] = lin
res[sel] = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs[sel], T)
res[.!sel] = linear_interpolant(Θ, dt, y₀, y₁, idxs[.!sel], T)
res
end

function partial_hermite_interpolant(Θ, dt, y₀::AT, y₁::AT, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI, AT <: AbstractArray}
if ArrayInterface.fast_scalar_indexing(AT)
partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, eachindex(IndexLinear(), y₀), T, differential_vars)
sel = vec(differential_vars)
res = similar(y₀)
res[differential_vars] = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, eachindex(y₀)[sel], T)
res[.!differential_vars] = linear_interpolant(Θ, dt, y₀, y₁, eachindex(y₀)[.!sel], T)
res
else
h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
Expand Down Expand Up @@ -665,18 +665,16 @@ function partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type
hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
end

function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractVector{Int}, T::Type{Val{TI}}, differential_vars) where {TI}
function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractArray{Int}, T::Type{Val{TI}}, differential_vars) where {TI}
sel = differential_vars[idxs]
diffidxs = idxs[sel]
linidxs = idxs[.!sel]
hermite_interpolant!(@view(out[sel]), Θ, dt, y₀, y₁, k, diffidxs, T)
linear_interpolant!(@view(out[.!sel]), Θ, dt, y₀, y₁, linidxs, T)
hermite_interpolant!(@view(out[sel]), Θ, dt, y₀, y₁, k, idxs[sel], T)
linear_interpolant!(@view(out[.!sel]), Θ, dt, y₀, y₁, idxs[.!sel], T)
out
end

function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AT, y₁::AT, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI, AT <: AbstractArray}
function partial_hermite_interpolant!(out::AT, Θ, dt, y₀::AT, y₁::AT, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI, AT <: AbstractArray}
if ArrayInterface.fast_scalar_indexing(AT)
partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(IndexLinear(), y₀), T, differential_vars)
partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars)
else
h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
Expand All @@ -703,7 +701,6 @@ end
function partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
throw("how did we get here")
@assert all(differential_vars)
differential_vars
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
end

Expand Down

0 comments on commit 191447a

Please sign in to comment.