diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index 2590488430..cb173e8fee 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -313,8 +313,8 @@ function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs, end end -function get_differential_vars(f, size) - differential_vars = trues(size) +function get_differential_vars(f, idxs, size) + differential_vars = nothing if hasproperty(f, :mass_matrix) mm = f.mass_matrix if mm isa UniformScaling @@ -326,7 +326,11 @@ function get_differential_vars(f, size) # @show typeof(mm) end end - return differential_vars + if idxs === nothing || differential_vars === nothing + return differential_vars + else + return differential_vars[idxs] + end end """ @@ -340,7 +344,7 @@ function ode_interpolation(tvals, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) idx = sortperm(tvals, rev = tdir < 0) - differential_vars = get_differential_vars(f, size(timeseries[begin])) + differential_vars = get_differential_vars(f, idxs, size(timeseries[begin])) # start the search thinking it's ts[1]-ts[2] i₋₊ref = Ref((1, 2)) vals = map(idx) do j @@ -378,7 +382,7 @@ function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) idx = sortperm(tvals, rev = tdir < 0) - differential_vars = get_differential_vars(f, size(timeseries[begin])) + differential_vars = get_differential_vars(f, idxs, size(timeseries[begin])) # start the search thinking it's in ts[1]-ts[2] i₋ = 1 @@ -464,7 +468,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) - differential_vars = get_differential_vars(f, size(timeseries[begin])) + differential_vars = get_differential_vars(f, idxs, size(timeseries[begin])) if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], @@ -514,7 +518,7 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) - differential_vars = get_differential_vars(f, size(timeseries[begin])) + differential_vars = get_differential_vars(f, idxs, size(timeseries[begin])) if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], @@ -576,12 +580,12 @@ end """ ode_interpolant and ode_interpolant! dispatch """ -function ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI} +function ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=nothing) where {TI} _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars) end function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCache, idxs, - T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI} + T::Type{Val{TI}}, differential_vars=nothing) where {TI} if idxs isa Number || y₀ isa Union{Number, SArray} # typeof(y₀) can be these if saveidxs gives a single value _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars) @@ -606,169 +610,94 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCach end end -function ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI} +function ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=nothing) where {TI} _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars) end ##################### Hermite Interpolants # If no dispatch found, assume Hermite -function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI} - partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars) -end - -function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=trues(size(y₀))) where {TI} - partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars) -end - -function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractArray{Int}, T::Type{Val{TI}}, differential_vars) where {TI} - sel = differential_vars[idxs] - res = similar(y₀, eltype(y₀), size(sel)) - res[sel] = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs[sel], T) - res[.!sel] = linear_interpolant(Θ, dt, y₀, y₁, idxs[.!sel], T) - res -end - -function partial_hermite_interpolant(Θ, dt, y₀::AT, y₁::AT, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI, AT <: AbstractArray} - if ArrayInterface.fast_scalar_indexing(AT) - sel = vec(differential_vars) - res = similar(y₀) - res[differential_vars] = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, eachindex(y₀)[sel], T) - res[.!differential_vars] = linear_interpolant(Θ, dt, y₀, y₁, eachindex(y₀)[.!sel], T) - res - else - h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) - l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T) - @.. broadcast=false h*differential_vars + l*!differential_vars - end -end - -function partial_hermite_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Int, T::Type{Val{TI}}, differential_vars) where {TI} - if differential_vars[idxs] - hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) - else - linear_interpolant(Θ, dt, y₀, y₁, idxs, T) - end -end - -function partial_hermite_interpolant(Θ, dt, y₀::Number, y₁::Number, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI} - if only(differential_vars) +function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars::DV=nothing) where {TI, DV} + if DV === Nothing hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) else - linear_interpolant(Θ, dt, y₀, y₁, idxs, T) + hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T, differential_vars) end end -function partial_hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI} - throw("how did we get here") - @assert all(differential_vars) - hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) -end - -function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractArray{Int}, T::Type{Val{TI}}, differential_vars) where {TI} - sel = differential_vars[idxs] - hermite_interpolant!(@view(out[sel]), Θ, dt, y₀, y₁, k, idxs[sel], T) - linear_interpolant!(@view(out[.!sel]), Θ, dt, y₀, y₁, idxs[.!sel], T) - out -end - -function partial_hermite_interpolant!(out::AT, Θ, dt, y₀::AT, y₁::AT, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI, AT <: AbstractArray} - if ArrayInterface.fast_scalar_indexing(AT) - partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars) - else - h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T) - l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T) - @.. broadcast=false out=h*differential_vars + l*!differential_vars - end -end - -function partial_hermite_interpolant!(out::AbstractArray, Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::Int, T::Type{Val{TI}}, differential_vars) where {TI} - if differential_vars[idxs] - hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T) - else - linear_interpolant!(out, Θ, dt, y₀, y₁, idxs, T) - end -end - -function partial_hermite_interpolant!(out, Θ, dt, y₀::Number, y₁::Number, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI} - if only(differential_vars) +function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars::DV=nothing) where {TI, DV} + if DV === Nothing hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T) else - linear_interpolant!(out, Θ, dt, y₀, y₁, idxs, T) + hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T, differential_vars) end end -function partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI} - throw("how did we get here") - @assert all(differential_vars) - hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T) -end - """ Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Problems Page 190 Herimte Interpolation, chosen if no other dispatch for ode_interpolant """ @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing, - T::Type{Val{0}}) # Default interpolant is Hermite + T::Type{Val{0}}, dv=1) # Default interpolant is Hermite #@.. broadcast=false (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2]) @inbounds (1 - Θ) * y₀ + Θ * y₁ + - Θ * (Θ - 1) * ((1 - 2Θ) * (y₁ - y₀) + (Θ - 1) * dt * k[1] + Θ * dt * k[2]) + dv * Θ * (Θ - 1) * ((1 - 2Θ) * (y₁ - y₀) + (Θ - 1) * dt * k[1] + Θ * dt * k[2]) end @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing, - T::Type{Val{0}}) # Default interpolant is Hermite + T::Type{Val{0}}, dv=1) # Default interpolant is Hermite #@.. broadcast=false (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2]) @inbounds @.. broadcast=false (1 - Θ)*y₀+Θ*y₁+ - Θ*(Θ-1)*((1 - 2Θ)*(y₁ - y₀)+(Θ-1)*dt*k[1]+Θ*dt*k[2]) + dv*Θ*(Θ-1)*((1 - 2Θ)*(y₁ - y₀)+(Θ-1)*dt*k[1]+Θ*dt*k[2]) end @muladd function hermite_interpolant(Θ, dt, y₀::Array, y₁, k, ::Type{Val{true}}, - idxs::Nothing, T::Type{Val{0}}) # Default interpolant is Hermite + idxs::Nothing, T::Type{Val{0}}, dv=trues(size(y₀))) # Default interpolant is Hermite out = similar(y₀) @inbounds @simd ivdep for i in eachindex(y₀) out[i] = (1 - Θ) * y₀[i] + Θ * y₁[i] + - Θ * (Θ - 1) * + dv[i] * Θ * (Θ - 1) * ((1 - 2Θ) * (y₁[i] - y₀[i]) + (Θ - 1) * dt * k[1][i] + Θ * dt * k[2][i]) end end -@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{0}}) # Default interpolant is Hermite +@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{0}}, dv=1) # Default interpolant is Hermite # return @.. broadcast=false (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs]) return (1 - Θ) * y₀[idxs] + Θ * y₁[idxs] + - Θ * (Θ - 1) * + dv * Θ * (Θ - 1) * ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + (Θ - 1) * dt * k[1][idxs] + Θ * dt * k[2][idxs]) end -@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{0}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{0}}, dv=1) # Default interpolant is Hermite @inbounds @.. broadcast=false out=(1 - Θ) * y₀ + Θ * y₁ + - Θ * (Θ - 1) * + dv * Θ * (Θ - 1) * ((1 - 2Θ) * (y₁ - y₀) + (Θ - 1) * dt * k[1] + Θ * dt * k[2]) end @muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs::Nothing, - T::Type{Val{0}}) # Default interpolant is Hermite + T::Type{Val{0}}, dv=trues(size(y₀))) # Default interpolant is Hermite @inbounds @simd ivdep for i in eachindex(out) out[i] = (1 - Θ) * y₀[i] + Θ * y₁[i] + - Θ * (Θ - 1) * + dv[i] * Θ * (Θ - 1) * ((1 - 2Θ) * (y₁[i] - y₀[i]) + (Θ - 1) * dt * k[1][i] + Θ * dt * k[2][i]) end out end -@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{0}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{0}}, dv=1) # Default interpolant is Hermite @views @.. broadcast=false out=(1 - Θ) * y₀[idxs] + Θ * y₁[idxs] + - Θ * (Θ - 1) * + dv * Θ * (Θ - 1) * ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + (Θ - 1) * dt * k[1][idxs] + Θ * dt * k[2][idxs]) end -@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{0}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{0}}, dv=trues(size(out))) # Default interpolant is Hermite @inbounds for (j, i) in enumerate(idxs) out[j] = (1 - Θ) * y₀[i] + Θ * y₁[i] + - Θ * (Θ - 1) * + dv[j] * Θ * (Θ - 1) * ((1 - 2Θ) * (y₁[i] - y₀[i]) + (Θ - 1) * dt * k[1][i] + Θ * dt * k[2][i]) end out @@ -778,7 +707,7 @@ end Herimte Interpolation, chosen if no other dispatch for ode_interpolant """ @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing, - T::Type{Val{1}}) # Default interpolant is Hermite + T::Type{Val{1}}, dv=I) # Default interpolant is Hermite #@.. broadcast=false k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt @inbounds k[1] + Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + @@ -786,14 +715,14 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant end @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing, - T::Type{Val{1}}) # Default interpolant is Hermite + T::Type{Val{1}}, dv=I) # Default interpolant is Hermite @inbounds @.. broadcast=false k[1]+Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + Θ * (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + 6 * y₁) / dt end -@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{1}}) # Default interpolant is Hermite +@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{1}}, dv=I) # Default interpolant is Hermite # return @.. broadcast=false k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt return k[1][idxs] + Θ * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + @@ -801,7 +730,7 @@ end 6 * y₁[idxs]) / dt end -@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{1}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{1}}, dv=I) # Default interpolant is Hermite @inbounds @.. broadcast=false out=k[1] + Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + Θ * @@ -810,7 +739,7 @@ end end @muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs::Nothing, - T::Type{Val{1}}) # Default interpolant is Hermite + T::Type{Val{1}}, dv=I) # Default interpolant is Hermite @inbounds @simd ivdep for i in eachindex(out) out[i] = k[1][i] + Θ * (-4 * dt * k[1][i] - 2 * dt * k[2][i] - 6 * y₀[i] + @@ -820,7 +749,7 @@ end out end -@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{1}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{1}}, dv=I) # Default interpolant is Hermite @views @.. broadcast=false out=k[1][idxs] + Θ * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + @@ -828,7 +757,7 @@ end 6 * y₀[idxs] - 6 * y₁[idxs]) + 6 * y₁[idxs]) / dt end -@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{1}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{1}}, dv=I) # Default interpolant is Hermite @inbounds for (j, i) in enumerate(idxs) out[j] = k[1][i] + Θ * (-4 * dt * k[1][i] - 2 * dt * k[2][i] - 6 * y₀[i] + @@ -842,21 +771,21 @@ end Herimte Interpolation, chosen if no other dispatch for ode_interpolant """ @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing, - T::Type{Val{2}}) # Default interpolant is Hermite + T::Type{Val{2}}, dv=I) # Default interpolant is Hermite #@.. broadcast=false (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt) @inbounds (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + 6 * y₁) / (dt * dt) end @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing, - T::Type{Val{2}}) # Default interpolant is Hermite + T::Type{Val{2}}, dv=I) # Default interpolant is Hermite #@.. broadcast=false (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt) @inbounds @.. broadcast=false (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + 6 * y₁)/(dt * dt) end -@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{2}}) # Default interpolant is Hermite +@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{2}}, dv=I) # Default interpolant is Hermite #out = similar(y₀,axes(idxs)) #@views @.. broadcast=false out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt) @views out = (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + @@ -865,7 +794,7 @@ end out end -@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{2}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{2}}, dv=I) # Default interpolant is Hermite @inbounds @.. broadcast=false out=(-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + @@ -873,7 +802,7 @@ end end @muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs::Nothing, - T::Type{Val{2}}) # Default interpolant is Hermite + T::Type{Val{2}}, dv=I) # Default interpolant is Hermite @inbounds @simd ivdep for i in eachindex(out) out[i] = (-4 * dt * k[1][i] - 2 * dt * k[2][i] - 6 * y₀[i] + Θ * (6 * dt * k[1][i] + 6 * dt * k[2][i] + 12 * y₀[i] - 12 * y₁[i]) + @@ -882,7 +811,7 @@ end out end -@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}, dv=I) # Default interpolant is Hermite @views @.. broadcast=false out=(-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + @@ -890,7 +819,7 @@ end (dt * dt) end -@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}, dv=I) # Default interpolant is Hermite @views @.. broadcast=false out=(-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + @@ -908,13 +837,13 @@ end Herimte Interpolation, chosen if no other dispatch for ode_interpolant """ @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing, - T::Type{Val{3}}) # Default interpolant is Hermite + T::Type{Val{3}}, dv=I) # Default interpolant is Hermite #@.. broadcast=false (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt) @inbounds (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / (dt * dt * dt) end @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing, - T::Type{Val{3}}) # Default interpolant is Hermite + T::Type{Val{3}}, dv=I) # Default interpolant is Hermite #@.. broadcast=false (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt) @inbounds @.. broadcast=false (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁)/(dt * @@ -922,7 +851,7 @@ end dt) end -@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{3}}) # Default interpolant is Hermite +@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{3}}, dv=I) # Default interpolant is Hermite #out = similar(y₀,axes(idxs)) #@views @.. broadcast=false out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt) @views out = (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - @@ -931,7 +860,7 @@ end out end -@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{3}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{3}}, dv=I) # Default interpolant is Hermite @inbounds @.. broadcast=false out=(6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / (dt * dt * dt) #for i in eachindex(out) @@ -941,7 +870,7 @@ end end @muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs::Nothing, - T::Type{Val{3}}) # Default interpolant is Hermite + T::Type{Val{3}}, dv=I) # Default interpolant is Hermite @inbounds @simd ivdep for i in eachindex(out) out[i] = (6 * dt * k[1][i] + 6 * dt * k[2][i] + 12 * y₀[i] - 12 * y₁[i]) / (dt * dt * dt) @@ -949,12 +878,12 @@ end out end -@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{3}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{3}}, dv=I) # Default interpolant is Hermite @views @.. broadcast=false out=(6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - 12 * y₁[idxs]) / (dt * dt * dt) end -@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{3}}) # Default interpolant is Hermite +@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{3}}, dv=I) # Default interpolant is Hermite @inbounds for (j, i) in enumerate(idxs) out[j] = (6 * dt * k[1][i] + 6 * dt * k[2][i] + 12 * y₀[i] - 12 * y₁[i]) / (dt * dt * dt) diff --git a/test/interface/inplace_interpolation.jl b/test/interface/inplace_interpolation.jl index b5ea6459ca..474ef15708 100644 --- a/test/interface/inplace_interpolation.jl +++ b/test/interface/inplace_interpolation.jl @@ -25,7 +25,7 @@ out_VMF = vecarrzero(ntt, size(prob_ode_2Dlinear.u0)) # Vector{Matrix{Float64} @test_throws MethodError sol_ODE(out_VF, tt; idxs = 1:1) @test sol_ODE(out_VF, tt) isa Vector{Float64} @test sol_ODE(out_VVF_1, tt) isa Vector{Vector{Float64}} - @test sol_ODE_interp.u == out_VF + @test sol_ODE_interp.u ≈ out_VF end @testset "2D" begin @@ -35,6 +35,6 @@ out_VMF = vecarrzero(ntt, size(prob_ode_2Dlinear.u0)) # Vector{Matrix{Float64} @test sol_ODE_2D(out_VVF_1, tt; idxs = 3:3) isa Vector{Vector{Float64}} @test sol_ODE_2D(out_VVF_2, tt; idxs = 2:3) isa Vector{Vector{Float64}} @test sol_ODE_2D(out_VMF, tt) isa Vector{Matrix{Float64}} - @test sol_ODE_2D_interp.u == out_VMF + @test sol_ODE_2D_interp.u ≈ out_VMF end end