diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index cb9ffb4593..f65d1c5e8a 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -315,14 +315,14 @@ function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs, end end -function get_differential_vars(f, len) - differential_vars = trues(len) +function get_differential_vars(f, size) + differential_vars = trues(size) if hasproperty(f, :mass_matrix) mm = f.mass_matrix if mm isa UniformScaling # already correct elseif isdiag(mm) - differential_vars = Diagonal(mm).diag .!= 0 + differential_vars = reshape(Diagonal(mm).diag .!= 0, size) else # QR factorization # @show typeof(mm) @@ -342,7 +342,7 @@ function ode_interpolation(tvals, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) idx = sortperm(tvals, rev = tdir < 0) - differential_vars = get_differential_vars(f, length(timeseries[begin])) + differential_vars = get_differential_vars(f, size(timeseries[begin])) # start the search thinking it's ts[1]-ts[2] i₋₊ref = Ref((1, 2)) vals = map(idx) do j @@ -380,7 +380,7 @@ function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) idx = sortperm(tvals, rev = tdir < 0) - differential_vars = get_differential_vars(f, length(timeseries[begin])) + differential_vars = get_differential_vars(f, size(timeseries[begin])) # start the search thinking it's in ts[1]-ts[2] i₋ = 1 @@ -466,7 +466,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) - differential_vars = get_differential_vars(f, length(timeseries[begin])) + differential_vars = get_differential_vars(f, size(timeseries[begin])) if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], @@ -516,7 +516,7 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) - differential_vars = get_differential_vars(f, length(timeseries[begin])) + differential_vars = get_differential_vars(f, size(timeseries[begin])) if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], @@ -636,8 +636,14 @@ function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::Abstract res end -function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI} - partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars) +function partial_hermite_interpolant(Θ, dt, y₀::AT, y₁::AT, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI, AT <: AbstractArray} + if ArrayInterface.fast_scalar_indexing(AT) + partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, eachindex(IndexLinear(), y₀), T, differential_vars) + else + h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) + l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T) + @.. broadcast=false h*differential_vars + l*!differential_vars + end end function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Int, T::Type{Val{TI}}, differential_vars) where {TI} @@ -672,8 +678,15 @@ function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::Abstract out end -function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI} - partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars) +function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AT, y₁::AT, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI, AT <: AbstractArray} + if ArrayInterface.fast_scalar_indexing(AT) + partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(IndexLinear(), y₀), T, differential_vars) + else + @show differential_vars + @show h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) + @show l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T) + @.. broadcast=false out=h*differential_vars + l*!differential_vars + end end function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Int, T::Type{Val{TI}}, differential_vars) where {TI} diff --git a/test/interface/noindex_tests.jl b/test/interface/noindex_tests.jl index d9c4107cd0..907229e4a9 100644 --- a/test/interface/noindex_tests.jl +++ b/test/interface/noindex_tests.jl @@ -25,6 +25,7 @@ function Base.similar(bc::Base.Broadcast.Broadcasted{NoIndexStyle{N}}, end Base.Broadcast._broadcast_getindex(x::NoIndexArray, i) = x.x[i] Base.Broadcast.extrude(x::NoIndexArray) = x +ArrayInterface.fast_scalar_indexing(::Type{<:NoIndexArray}) = false @inline function Base.copyto!(dest::NoIndexArray, bc::Base.Broadcast.Broadcasted{<:NoIndexStyle})