diff --git a/Project.toml b/Project.toml index 41990bcb..59d27ef9 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,7 @@ DataInterpolationsSymbolicsExt = "Symbolics" Aqua = "0.8" BenchmarkTools = "1" ChainRulesCore = "1.24" -FindFirstFunctions = "1.1" +FindFirstFunctions = "1.3" FiniteDifferences = "0.12.31" ForwardDiff = "0.10.36" LinearAlgebra = "1.10" diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 312e67e3..988e519b 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -52,7 +52,7 @@ end function u_tangent(A::LinearInterpolation, t, Δ) out = zero(A.u) - idx = get_idx(A, t, A.idx_prev[]) + idx = get_idx(A, t, A.iguesser) t_factor = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) out[idx] = Δ * (one(eltype(out)) - t_factor) out[idx + 1] = Δ * t_factor @@ -61,7 +61,7 @@ end function u_tangent(A::QuadraticInterpolation, t, Δ) out = zero(A.u) - i₀, i₁, i₂ = _quad_interp_indices(A, t, A.idx_prev[]) + i₀, i₁, i₂ = _quad_interp_indices(A, t, A.iguesser) t₀ = A.t[i₀] t₁ = A.t[i₁] t₂ = A.t[i₂] diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index 7f44c878..5f035e62 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -8,7 +8,7 @@ using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, - bracketstrictlymontonic + Guesser include("parameter_caches.jl") include("interpolation_caches.jl") @@ -22,10 +22,6 @@ include("online.jl") include("show.jl") (interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t) -function (interp::AbstractInterpolation)(t::Number, i::Integer) - interp.idx_prev[] = i - _interpolate(interp, t) -end function (interp::AbstractInterpolation)(t::AbstractVector) u = get_u(interp.u, t) diff --git a/src/derivatives.jl b/src/derivatives.jl index e7cb97fd..ff5b9850 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -1,16 +1,12 @@ function derivative(A, t, order = 1) ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) - iguess = A.idx_prev[] + iguess = A.iguesser return if order == 1 - val, idx = _derivative(A, t, iguess) - A.idx_prev[] = idx - val + _derivative(A, t, iguess) elseif order == 2 ForwardDiff.derivative(t -> begin - val, idx = _derivative(A, t, iguess) - A.idx_prev[] = idx - val + _derivative(A, t, iguess) end, t) else throw(DerivativeNotFoundError()) @@ -20,7 +16,7 @@ end function _derivative(A::LinearInterpolation, t::Number, iguess) idx = get_idx(A, t, iguess; idx_shift = -1, ub_shift = -1, side = :first) slope = get_parameters(A, idx) - slope, idx + slope end function _derivative(A::QuadraticInterpolation, t::Number, iguess) @@ -29,7 +25,7 @@ function _derivative(A::QuadraticInterpolation, t::Number, iguess) du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂]) du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂]) du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁]) - return @views @. du₀ + du₁ + du₂, i₀ + return @views @. du₀ + du₁ + du₂ end function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number) @@ -101,21 +97,21 @@ function _derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number) end function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number, idx) - _derivative(A, t), idx + _derivative(A, t) end function _derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, idx) - _derivative(A, t), idx + _derivative(A, t) end function _derivative(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess) idx = get_idx(A, t, iguess; idx_shift = -1, side = :first) j = min(idx, length(A.c)) # for smooth derivative at A.t[end] wj = t - A.t[idx] - (@evalpoly wj A.b[idx] 2A.c[j] 3A.d[j]), idx + @evalpoly wj A.b[idx] 2A.c[j] 3A.d[j] end function _derivative(A::ConstantInterpolation, t::Number, iguess) - return zero(first(A.u)), iguess + return zero(first(A.u)) end function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number) @@ -132,7 +128,7 @@ end function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A, t, iguess; lb = 2, ub_shift = 0, side = :first) σ = get_parameters(A, idx - 1) - A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx + A.z[idx - 1] + 2σ * (t - A.t[idx - 1]) end # CubicSpline Interpolation @@ -144,13 +140,13 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess) c₁, c₂ = get_parameters(A, idx) dC = c₁ dD = -c₂ - dI + dC + dD, idx + dI + dC + dD end function _derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Number, iguess) # change t into param [0 1] - t < A.t[1] && return zero(A.u[1]), 1 - t > A.t[end] && return zero(A.u[end]), lastindex(t) + t < A.t[1] && return zero(A.u[1]) + t > A.t[end] && return zero(A.u[end]) idx = get_idx(A, t, iguess) n = length(A.t) scale = (A.p[idx + 1] - A.p[idx]) / (A.t[idx + 1] - A.t[idx]) @@ -165,14 +161,14 @@ function _derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Num ducum += N[i + 1] * (A.c[i + 1] - A.c[i]) / (A.k[i + A.d + 1] - A.k[i + 1]) end end - ducum * A.d * scale, idx + ducum * A.d * scale end # BSpline Curve Approx function _derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, iguess) # change t into param [0 1] - t < A.t[1] && return zero(A.u[1]), 1 - t > A.t[end] && return zero(A.u[end]), lastindex(t) + t < A.t[1] && return zero(A.u[1]) + t > A.t[end] && return zero(A.u[end]) idx = get_idx(A, t, iguess) scale = (A.p[idx + 1] - A.p[idx]) / (A.t[idx + 1] - A.t[idx]) t_ = A.p[idx] + (t - A.t[idx]) * scale @@ -186,7 +182,7 @@ function _derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, ig ducum += N[i + 1] * (A.c[i + 1] - A.c[i]) / (A.k[i + A.d + 1] - A.k[i + 1]) end end - ducum * A.d * scale, idx + ducum * A.d * scale end # Cubic Hermite Spline @@ -198,7 +194,7 @@ function _derivative( out = A.du[idx] c₁, c₂ = get_parameters(A, idx) out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂)) - out, idx + out end # Quintic Hermite Spline @@ -211,5 +207,5 @@ function _derivative( c₁, c₂, c₃ = get_parameters(A, idx) out += Δt₀^2 * (3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃) - out, idx + out end diff --git a/src/integral_inverses.jl b/src/integral_inverses.jl index aebee9d5..4cddf7d3 100644 --- a/src/integral_inverses.jl +++ b/src/integral_inverses.jl @@ -18,7 +18,7 @@ invert_integral(A::AbstractInterpolation) = throw(IntegralInverseNotFoundError() _integral(A::AbstractIntegralInverseInterpolation, idx, t) = throw(IntegralNotFoundError()) function _derivative(A::AbstractIntegralInverseInterpolation, t::Number, iguess) - inv(A.itp(A(t))), A.idx_prev[] + inv(A.itp(A(t))) end """ @@ -38,11 +38,11 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <: u::uType t::tType extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} itp::itpType function LinearInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A) + u, t, A.extrapolate, Guesser(t), A) end end @@ -64,7 +64,7 @@ function _interpolate( x = A.itp.u[idx] slope = get_parameters(A.itp, idx) u = A.u[idx] + 2Δt / (x + sqrt(x^2 + slope * 2Δt)) - u, idx + u end """ @@ -84,11 +84,11 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <: u::uType t::tType extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} itp::itpType function ConstantInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A + u, t, A.extrapolate, Guesser(t), A ) end end @@ -112,5 +112,5 @@ function _interpolate( # :right means that value to the right is used for interpolation idx_ = get_idx(A, t, idx; side = :first, lb = 1, ub_shift = 0) end - A.u[idx] + (t - A.t[idx]) / A.itp.u[idx_], idx + A.u[idx] + (t - A.t[idx]) / A.itp.u[idx_] end diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 71b2b042..d7dce336 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -25,13 +25,13 @@ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolati I::IType p::LinearParameterCache{pType} extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool function LinearInterpolation(u, t, I, p, extrapolate, cache_parameters, assume_linear_t) linear_lookup = seems_linear(assume_linear_t, t) new{typeof(u), typeof(t), typeof(I), typeof(p.slope), eltype(u)}( - u, t, I, p, extrapolate, Ref(1), cache_parameters, linear_lookup) + u, t, I, p, extrapolate, Guesser(t), cache_parameters, linear_lookup) end end @@ -73,7 +73,7 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpol p::QuadraticParameterCache{pType} mode::Symbol extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool function QuadraticInterpolation( @@ -82,7 +82,7 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpol error("mode should be :Forward or :Backward for QuadraticInterpolation") linear_lookup = seems_linear(assume_linear_t, t) new{typeof(u), typeof(t), typeof(I), typeof(p.l₀), eltype(u)}( - u, t, I, p, mode, extrapolate, Ref(1), cache_parameters, linear_lookup) + u, t, I, p, mode, extrapolate, Guesser(t), cache_parameters, linear_lookup) end end @@ -124,7 +124,7 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: bcache::bcacheType idxs::Vector{Int} extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} function LagrangeInterpolation(u, t, n, extrapolate) bcache = zeros(eltype(u[1]), n + 1) idxs = zeros(Int, n + 1) @@ -135,7 +135,7 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: bcache, idxs, extrapolate, - Ref(1) + Guesser(t) ) end end @@ -178,7 +178,7 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: c::cType d::dType extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool function AkimaInterpolation( @@ -192,7 +192,7 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: c, d, extrapolate, - Ref(1), + Guesser(t), cache_parameters, linear_lookup ) @@ -258,14 +258,14 @@ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} p::Nothing dir::Symbol # indicates if value to the $dir should be used for the interpolation extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool function ConstantInterpolation( u, t, I, dir, extrapolate, cache_parameters, assume_linear_t) linear_lookup = seems_linear(assume_linear_t, t) new{typeof(u), typeof(t), typeof(I), eltype(u)}( - u, t, I, nothing, dir, extrapolate, Ref(1), cache_parameters, linear_lookup) + u, t, I, nothing, dir, extrapolate, Guesser(t), cache_parameters, linear_lookup) end end @@ -309,7 +309,7 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: d::dType z::zType extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool function QuadraticSpline( @@ -324,7 +324,7 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: d, z, extrapolate, - Ref(1), + Guesser(t), cache_parameters, linear_lookup ) @@ -410,7 +410,7 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter h::hType z::zType extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool function CubicSpline(u, t, I, p, h, z, extrapolate, cache_parameters, assume_linear_t) @@ -423,7 +423,7 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter h, z, extrapolate, - Ref(1), + Guesser(t), cache_parameters, linear_lookup ) @@ -520,7 +520,7 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: pVecType::Symbol knotVecType::Symbol extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} linear_lookup::Bool function BSplineInterpolation(u, t, @@ -544,7 +544,7 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), + Guesser(t), linear_lookup ) end @@ -656,7 +656,7 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: pVecType::Symbol knotVecType::Symbol extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} linear_lookup::Bool function BSplineApprox(u, t, @@ -683,7 +683,7 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), + Guesser(t), linear_lookup ) end @@ -806,14 +806,14 @@ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInte I::IType p::CubicHermiteParameterCache{pType} extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool function CubicHermiteSpline( du, u, t, I, p, extrapolate, cache_parameters, assume_linear_t) linear_lookup = seems_linear(assume_linear_t, t) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(p.c₁), eltype(u)}( - du, u, t, I, p, extrapolate, Ref(1), cache_parameters, linear_lookup) + du, u, t, I, p, extrapolate, Guesser(t), cache_parameters, linear_lookup) end end @@ -887,7 +887,7 @@ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: I::IType p::QuinticHermiteParameterCache{pType} extrapolate::Bool - idx_prev::Base.RefValue{Int} + iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool function QuinticHermiteSpline( @@ -895,7 +895,7 @@ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: linear_lookup = seems_linear(assume_linear_t, t) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(ddu), typeof(p.c₁), eltype(u)}( - ddu, du, u, t, I, p, extrapolate, Ref(1), cache_parameters, linear_lookup) + ddu, du, u, t, I, p, extrapolate, Guesser(t), cache_parameters, linear_lookup) end end diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 69a5cf94..697f7375 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -1,10 +1,7 @@ function _interpolate(A, t) ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) - idx_guess = A.idx_prev[] - val, idx_prev = _interpolate(A, t, idx_guess) - A.idx_prev[] = idx_prev - return val + return _interpolate(A, t, A.iguesser) end # Linear Interpolation @@ -33,14 +30,14 @@ function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, igues end val = oftype(Δu, val) - val, idx + val end function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess) idx = get_idx(A, t, iguess) Δt = t - A.t[idx] slope = get_parameters(A, idx) - return A.u[:, idx] + slope * Δt, idx + return A.u[:, idx] + slope * Δt end # Quadratic Interpolation @@ -56,7 +53,7 @@ function _interpolate(A::QuadraticInterpolation, t::Number, iguess) u₀ = l₀ * (t - A.t[i₁]) * (t - A.t[i₂]) u₁ = l₁ * (t - A.t[i₀]) * (t - A.t[i₂]) u₂ = l₂ * (t - A.t[i₀]) * (t - A.t[i₁]) - return u₀ + u₁ + u₂, i₀ + return u₀ + u₁ + u₂ end # Lagrange Interpolation @@ -64,7 +61,7 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number, igu idx = get_idx(A, t, iguess) findRequiredIdxs!(A, t, idx) if A.t[A.idxs[1]] == t - return A.u[A.idxs[1]], idx + return A.u[A.idxs[1]] end N = zero(A.u[1]) D = zero(A.t[1]) @@ -86,14 +83,14 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number, igu D += tmp N += (tmp * A.u[A.idxs[i]]) end - N / D, idx + N / D end function _interpolate(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, iguess) idx = get_idx(A, t, iguess) findRequiredIdxs!(A, t, idx) if A.t[A.idxs[1]] == t - return A.u[:, A.idxs[1]], idx + return A.u[:, A.idxs[1]] end N = zero(A.u[:, 1]) D = zero(A.t[1]) @@ -115,13 +112,13 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, igu D += tmp @. N += (tmp * A.u[:, A.idxs[i]]) end - N / D, idx + N / D end function _interpolate(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess) idx = get_idx(A, t, iguess) wj = t - A.t[idx] - (@evalpoly wj A.u[idx] A.b[idx] A.c[idx] A.d[idx]), idx + @evalpoly wj A.u[idx] A.b[idx] A.c[idx] A.d[idx] end # ConstantInterpolation Interpolation @@ -133,7 +130,7 @@ function _interpolate(A::ConstantInterpolation{<:AbstractVector}, t::Number, igu # :right means that value to the right is used for interpolation idx = get_idx(A, t, iguess; side = :first, lb = 1, ub_shift = 0) end - A.u[idx], idx + A.u[idx] end function _interpolate(A::ConstantInterpolation{<:AbstractMatrix}, t::Number, iguess) @@ -144,7 +141,7 @@ function _interpolate(A::ConstantInterpolation{<:AbstractMatrix}, t::Number, igu # :right means that value to the right is used for interpolation idx = get_idx(A, t, iguess; side = :first, lb = 1, ub_shift = 0) end - A.u[:, idx], idx + A.u[:, idx] end # QuadraticSpline Interpolation @@ -153,7 +150,7 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) Cᵢ = A.u[idx] Δt = t - A.t[idx] σ = get_parameters(A, idx) - return A.z[idx] * Δt + σ * Δt^2 + Cᵢ, idx + return A.z[idx] * Δt + σ * Δt^2 + Cᵢ end # CubicSpline Interpolation @@ -165,15 +162,15 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess) c₁, c₂ = get_parameters(A, idx) C = c₁ * Δt₁ D = c₂ * Δt₂ - I + C + D, idx + I + C + D end # BSpline Curve Interpolation function _interpolate(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Number, iguess) - t < A.t[1] && return A.u[1], 1 - t > A.t[end] && return A.u[end], lastindex(t) + t < A.t[1] && return A.u[1] + t > A.t[end] && return A.u[end] # change t into param [0 1] idx = get_idx(A, t, iguess) t = A.p[idx] + (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) * (A.p[idx + 1] - A.p[idx]) @@ -184,13 +181,13 @@ function _interpolate(A::BSplineInterpolation{<:AbstractVector{<:Number}}, for i in nonzero_coefficient_idxs ucum += N[i] * A.c[i] end - ucum, idx + ucum end # BSpline Curve Approx function _interpolate(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, iguess) - t < A.t[1] && return A.u[1], 1 - t > A.t[end] && return A.u[end], lastindex(t) + t < A.t[1] && return A.u[1] + t > A.t[end] && return A.u[end] # change t into param [0 1] idx = get_idx(A, t, iguess) t = A.p[idx] + (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) * (A.p[idx + 1] - A.p[idx]) @@ -200,7 +197,7 @@ function _interpolate(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, i for i in nonzero_coefficient_idxs ucum += N[i] * A.c[i] end - ucum, idx + ucum end # Cubic Hermite Spline @@ -212,7 +209,7 @@ function _interpolate( out = A.u[idx] + Δt₀ * A.du[idx] c₁, c₂ = get_parameters(A, idx) out += Δt₀^2 * (c₁ + Δt₁ * c₂) - out, idx + out end # Quintic Hermite Spline @@ -224,5 +221,5 @@ function _interpolate( out = A.u[idx] + Δt₀ * (A.du[idx] + A.ddu[idx] * Δt₀ / 2) c₁, c₂, c₃ = get_parameters(A, idx) out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁)) - out, idx + out end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index ccd3db63..0df5e484 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -117,16 +117,8 @@ function looks_linear(t; threshold = 1e-2) norm_var < threshold^2 end -function get_idx(A::AbstractInterpolation, t, iguess; lb = 1, +function get_idx(A::AbstractInterpolation, t, iguess::Union{<:Integer, Guesser}; lb = 1, ub_shift = -1, idx_shift = 0, side = :last) - iguess = if hasfield(typeof(A), :linear_lookup) && - A.linear_lookup - f = (t - first(A.t)) / (last(A.t) - first(A.t)) - i_0, i_f = firstindex(A.t), lastindex(A.t) - round(typeof(firstindex(A.t)), f * (i_f - i_0) + i_0) - else - iguess - end tvec = A.t ub = length(tvec) + ub_shift return if side == :last diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 50abe4ac..a4eb349b 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -39,9 +39,9 @@ function test_derivatives(method; args = [], kwargs = [], name::String) @test isapprox(fdiff, adiff, atol = 1e-8) @test isapprox(fdiff2, adiff2, atol = 1e-8) # Cached index - if hasproperty(func, :idx_prev) - @test abs(func.idx_prev[] - - searchsortedfirstcorrelated(func.t, _t, func.idx_prev[])) <= 1 + if hasproperty(func, :iguesser) && !func.iguesser.linear_lookup + @test abs(func.iguesser.idx_prev[] - + searchsortedfirstcorrelated(func.t, _t, func.iguesser(_t))) <= 1 end end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 1d4e2e6b..3ebc1741 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -9,7 +9,7 @@ function test_interpolation_type(T) @test hasfield(T, :u) @test hasfield(T, :t) @test hasfield(T, :extrapolate) - @test hasfield(T, :idx_prev) + @test hasfield(T, :iguesser) @test !isempty(methods(DataInterpolations._interpolate, (T, Any, Number))) @test !isempty(methods(DataInterpolations._integral, (T, Any, Number))) @test !isempty(methods(DataInterpolations._derivative, (T, Any, Number))) @@ -18,9 +18,9 @@ end function test_cached_index(A) for t in range(first(A.t), last(A.t); length = 2 * length(A.t) - 1) A(t) - idx = searchsortedfirstcorrelated(A.t, t, A.idx_prev[]) - @test abs(A.idx_prev[] - - searchsortedfirstcorrelated(A.t, t, A.idx_prev[])) <= 2 + idx = searchsortedfirstcorrelated(A.t, t, A.iguesser) + @test abs(A.iguesser.idx_prev[] - + searchsortedfirstcorrelated(A.t, t, A.iguesser)) <= 2 end end @@ -492,6 +492,13 @@ end A = ConstantInterpolation(u, t) @test_throws DataInterpolations.ExtrapolationError A(-1.0) @test_throws DataInterpolations.ExtrapolationError A(11.0) + + # Test extrapolation with infs with regularly spaced t + u = [1.67e7, 1.6867e7, 1.7034e7, 1.7201e7, 1.7368e7] + t = [0.0, 0.1, 0.2, 0.3, 0.4] + A = ConstantInterpolation(u, t; extrapolate = true) + @test A(Inf) == last(u) + @test A(-Inf) == first(u) end @testset "QuadraticSpline Interpolation" begin @@ -801,20 +808,3 @@ f_cubic_spline = c -> square(CubicSpline, c) @test ForwardDiff.derivative(f_quadratic_spline, 4.0) ≈ 8.0 @test ForwardDiff.derivative(f_cubic_spline, 2.0) ≈ 4.0 @test ForwardDiff.derivative(f_cubic_spline, 4.0) ≈ 8.0 - -@testset "Linear lookup" begin - for N in (100, 1_000, 10_000, 100_000, 1_000_000) # Interpolant size - seed = 1234 - t = collect(LinRange(0, 1, N)) # collect to avoid fast LinRange dispatch - y = rand(N) - A = LinearInterpolation(y, t) - A_fallback = LinearInterpolation(copy(y), copy(t); assume_linear_t = false) - @test A.linear_lookup - @test !(A_fallback.linear_lookup) - - n_samples = 1_000 - b_linear = @benchmark $A(rand()) samples=n_samples - b_fallback = @benchmark $A_fallback(rand()) samples=n_samples - @test mean(b_linear.times) < mean(b_fallback.times) - end -end