Skip to content

Commit

Permalink
handle noindex_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Pepijn de Vos committed Nov 14, 2023
1 parent 809309c commit 6590b87
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,14 +315,14 @@ function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs,
end
end

function get_differential_vars(f, len)
differential_vars = trues(len)
function get_differential_vars(f, size)
differential_vars = trues(size)
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if mm isa UniformScaling
# already correct
elseif isdiag(mm)
differential_vars = Diagonal(mm).diag .!= 0
differential_vars = reshape(Diagonal(mm).diag .!= 0, size)
else
# QR factorization
# @show typeof(mm)
Expand All @@ -342,7 +342,7 @@ function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])
idx = sortperm(tvals, rev = tdir < 0)
differential_vars = get_differential_vars(f, length(timeseries[begin]))
differential_vars = get_differential_vars(f, size(timeseries[begin]))
# start the search thinking it's ts[1]-ts[2]
i₋₊ref = Ref((1, 2))
vals = map(idx) do j
Expand Down Expand Up @@ -380,7 +380,7 @@ function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])
idx = sortperm(tvals, rev = tdir < 0)
differential_vars = get_differential_vars(f, length(timeseries[begin]))
differential_vars = get_differential_vars(f, size(timeseries[begin]))

# start the search thinking it's in ts[1]-ts[2]
i₋ = 1
Expand Down Expand Up @@ -466,7 +466,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])

differential_vars = get_differential_vars(f, length(timeseries[begin]))
differential_vars = get_differential_vars(f, size(timeseries[begin]))

if continuity === :left
# we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end],
Expand Down Expand Up @@ -516,7 +516,7 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])

differential_vars = get_differential_vars(f, length(timeseries[begin]))
differential_vars = get_differential_vars(f, size(timeseries[begin]))

if continuity === :left
# we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end],
Expand Down Expand Up @@ -636,8 +636,14 @@ function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::Abstract
res
end

function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI}
partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars)
function partial_hermite_interpolant(Θ, dt, y₀::AT, y₁::AT, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI, AT <: AbstractArray}
if ArrayInterface.fast_scalar_indexing(AT)
partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, eachindex(IndexLinear(), y₀), T, differential_vars)
else
h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
@.. broadcast=false h*differential_vars + l*!differential_vars
end
end

function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Int, T::Type{Val{TI}}, differential_vars) where {TI}
Expand Down Expand Up @@ -672,8 +678,15 @@ function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::Abstract
out
end

function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI}
partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars)
function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AT, y₁::AT, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI, AT <: AbstractArray}
if ArrayInterface.fast_scalar_indexing(AT)
partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(IndexLinear(), y₀), T, differential_vars)
else
@show differential_vars
@show h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
@show l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
@.. broadcast=false out=h*differential_vars + l*!differential_vars
end
end

function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Int, T::Type{Val{TI}}, differential_vars) where {TI}
Expand Down
1 change: 1 addition & 0 deletions test/interface/noindex_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ function Base.similar(bc::Base.Broadcast.Broadcasted{NoIndexStyle{N}},
end
Base.Broadcast._broadcast_getindex(x::NoIndexArray, i) = x.x[i]
Base.Broadcast.extrude(x::NoIndexArray) = x
ArrayInterface.fast_scalar_indexing(::Type{<:NoIndexArray}) = false

@inline function Base.copyto!(dest::NoIndexArray,
bc::Base.Broadcast.Broadcasted{<:NoIndexStyle})
Expand Down

0 comments on commit 6590b87

Please sign in to comment.