diff --git a/Project.toml b/Project.toml index d4133013..c9b7a2e9 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -16,6 +15,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Optim = "429524aa-4258-5aef-a3af-852621145aeb" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DataInterpolationsChainRulesCoreExt = "ChainRulesCore" @@ -33,7 +33,6 @@ LinearAlgebra = "1.10" Optim = "1.6" PrettyTables = "2" QuadGK = "2.9.1" -ReadOnlyArrays = "0.2.0" RecipesBase = "1.3" Reexport = "1" RegularizationTools = "0.6" @@ -41,6 +40,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" +Zygote = "0.6" julia = "1.10" [extras] @@ -55,6 +55,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"] +test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Zygote"] diff --git a/docs/src/interface.md b/docs/src/interface.md index ca5e9819..cfc9ed0b 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -35,22 +35,7 @@ A2(300.0) The values computed beyond the range of the time points provided during interpolation will not be reliable, as these methods only perform well within the range and the first/last piece polynomial fit is extrapolated on either side which might not reflect the true nature of the data. -The keyword `safetycopy = false` can be passed to make sure no copies of `u` and `t` are made when initializing the interpolation object. - -```@example interface -A3 = QuadraticInterpolation(u, t; safetycopy = false) - -# Check for same memory -u === A3.u.parent -``` - -Note that this does not prevent allocation in every interpolation constructor call, because parameter values are cached for all interpolation types except [`ConstantInterpolation`](@ref). - -Because of the caching of parameters which depend on `u` and `t`, this data should not be mutated. Therefore `u` and `t` are wrapped in a `ReadOnlyArray` from [ReadOnlyArrays.jl](https://github.com/JuliaArrays/ReadOnlyArrays.jl). - -```@repl interface -A3.t[2] = 3.14 -``` +The keyword `cache_parameters = true` can be passed to precalculate parameters at initialization, making evalations cheaper to compute. This is not compatible with modifying `u` and `t`. The default `cache_parameters = false` does however not prevent allocation in every interpolation constructor call. ## Derivatives diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 34e27841..9c33b09c 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -1,19 +1,87 @@ module DataInterpolationsChainRulesCoreExt - if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, + LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, get_idx, get_parameters, + _quad_interp_indices using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, + LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, get_parameters, + _quad_interp_indices using ..ChainRulesCore end +function ChainRulesCore.rrule( + ::Type{LinearInterpolation}, u, t, I, p, extrapolate, cache_parameters) + A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) + function LinearInterpolation_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = NoTangent() + dextrapolate = NoTangent() + dcache_parameters = NoTangent() + df, du, dt, dI, dp, dextrapolate, dcache_parameters + end + + A, LinearInterpolation_pullback +end + +function ChainRulesCore.rrule( + ::Type{QuadraticInterpolation}, u, t, I, p, mode, extrapolate, cache_parameters) + A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) + function LinearInterpolation_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = NoTangent() + dmode = NoTangent() + dextrapolate = NoTangent() + dcache_parameters = NoTangent() + df, du, dt, dI, dp, dmode, dextrapolate, dcache_parameters + end + + A, LinearInterpolation_pullback +end + +function u_tangent(A::LinearInterpolation, t, Δ) + out = zero(A.u) + idx = get_idx(A.t, t, A.idx_prev[]) + t_factor = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) + out[idx] = Δ * (one(eltype(out)) - t_factor) + out[idx + 1] = Δ * t_factor + out +end + +function u_tangent(A::QuadraticInterpolation, t, Δ) + out = zero(A.u) + i₀, i₁, i₂ = _quad_interp_indices(A, t, A.idx_prev[]) + t₀ = A.t[i₀] + t₁ = A.t[i₁] + t₂ = A.t[i₂] + Δt₀ = t₁ - t₀ + Δt₁ = t₂ - t₁ + Δt₂ = t₂ - t₀ + out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂) + out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁) + out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁) + out +end + +function u_tangent(A, t, Δ) + NoTangent() +end + function ChainRulesCore.rrule(::typeof(_interpolate), A::Union{ + LinearInterpolation, + QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, @@ -21,7 +89,9 @@ function ChainRulesCore.rrule(::typeof(_interpolate), }, t::Number) deriv = derivative(A, t) - interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) + function interpolate_pullback(Δ) + (NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), deriv * Δ) + end return _interpolate(A, t), interpolate_pullback end diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl index 5528503f..b3bce295 100644 --- a/ext/DataInterpolationsOptimExt.jl +++ b/ext/DataInterpolationsOptimExt.jl @@ -18,9 +18,8 @@ function Curvefit(u, box = false, lb = nothing, ub = nothing; - extrapolate = false, - safetycopy = false) - u, t = munge_data(u, t, safetycopy) + extrapolate = false) + u, t = munge_data(u, t) errfun(t, u, p) = sum(abs2.(u .- model(t, p))) if box == false mfit = optimize(p -> errfun(t, u, p), p0, alg) diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl index 10ea3e4c..732ea1bb 100644 --- a/ext/DataInterpolationsRegularizationToolsExt.jl +++ b/ext/DataInterpolationsRegularizationToolsExt.jl @@ -69,8 +69,8 @@ A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ = 1.0, alg = :gcv_svd) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = LA.diagm(sqrt.(wr)) @@ -86,8 +86,8 @@ A = RegularizationSmooth(u, t, d; λ = 1.0, alg = :gcv_svd, extrapolate = false) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -115,8 +115,8 @@ A = RegularizationSmooth(u, t, t̂, d; λ = 1.0, alg = :gcv_svd, extrapolate = f """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = Array{Float64}(LA.I, N, N) @@ -143,8 +143,8 @@ A = RegularizationSmooth(u, t, t̂, wls, d; λ = 1.0, alg = :gcv_svd, extrapolat """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) @@ -172,8 +172,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -202,8 +202,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -232,8 +232,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index d96ea69a..7f44c878 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -7,7 +7,6 @@ abstract type AbstractInterpolation{T} end using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff -using ReadOnlyArrays import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, bracketstrictlymontonic @@ -90,12 +89,6 @@ function Base.showerror(io::IO, e::IntegralNotInvertibleError) print(io, INTEGRAL_NOT_INVERTIBLE_ERROR) end -const MUST_COPY_ERROR = "A copy must be made of u, t to filter missing data" -struct MustCopyError <: Exception end -function Base.showerror(io::IO, e::MustCopyError) - print(io, MUST_COPY_ERROR) -end - export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, BSplineInterpolation, BSplineApprox, CubicHermiteSpline, PCHIPInterpolation, @@ -128,12 +121,12 @@ struct RegularizationSmooth{uType, tType, T, T2, ITP <: AbstractInterpolation{T} Aitp, extrapolate) new{typeof(u), typeof(t), eltype(u), typeof(λ), typeof(Aitp)}( - readonly_wrap(u), - readonly_wrap(û), - readonly_wrap(t), - readonly_wrap(t̂), - readonly_wrap(oftype(u.parent, wls)), - readonly_wrap(oftype(u.parent, wr)), + u, + û, + t, + t̂, + wls, + wr, d, λ, alg, diff --git a/src/derivatives.jl b/src/derivatives.jl index 4d8f7189..75872095 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -19,14 +19,16 @@ end function _derivative(A::LinearInterpolation, t::Number, iguess) idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -1, side = :first) - A.p.slope[idx], idx + slope = get_parameters(A, idx) + slope, idx end function _derivative(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - du₀ = A.p.l₀[i₀] * (2t - A.t[i₁] - A.t[i₂]) - du₁ = A.p.l₁[i₀] * (2t - A.t[i₀] - A.t[i₂]) - du₂ = A.p.l₂[i₀] * (2t - A.t[i₀] - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂]) + du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂]) + du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁]) return @views @. du₀ + du₁ + du₂, i₀ end @@ -129,7 +131,7 @@ end # QuadraticSpline Interpolation function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first) - σ = A.p.σ[idx - 1] + σ = get_parameters(A, idx - 1) A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx end @@ -139,8 +141,9 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t dI = (-A.z[idx] * Δt₂^2 + A.z[idx + 1] * Δt₁^2) / (2A.h[idx + 1]) - dC = A.p.c₁[idx] - dD = -A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + dC = c₁ + dD = -c₂ dI + dC + dD, idx end @@ -193,7 +196,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] - out += Δt₀ * (Δt₀ * A.p.c₂[idx] + 2(A.p.c₁[idx] + Δt₁ * A.p.c₂[idx])) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂)) out, idx end @@ -204,7 +208,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] + A.ddu[idx] * Δt₀ + c₁, c₂, c₃ = get_parameters(A, idx) out += Δt₀^2 * - (3A.p.c₁[idx] + (3Δt₁ + Δt₀) * A.p.c₂[idx] + (3Δt₁^2 + Δt₀ * 2Δt₁) * A.p.c₃[idx]) + (3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃) out, idx end diff --git a/src/integral_inverses.jl b/src/integral_inverses.jl index 4437726e..31d0853c 100644 --- a/src/integral_inverses.jl +++ b/src/integral_inverses.jl @@ -40,10 +40,9 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function LinearInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy) + u, t, A.extrapolate, Ref(1), A) end end @@ -51,9 +50,11 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) return all(A.u .> 0) end +get_I(A::AbstractInterpolation) = isempty(A.I) ? cumulative_integral(A, true) : A.I + function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return LinearInterpolationIntInv(A.t, A.I, A) + return LinearInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( @@ -61,7 +62,8 @@ function _interpolate( idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] x = A.itp.u[idx] - u = A.u[idx] + 2Δt / (x + sqrt(x^2 + A.itp.p.slope[idx] * 2Δt)) + slope = get_parameters(A.itp, idx) + u = A.u[idx] + 2Δt / (x + sqrt(x^2 + slope * 2Δt)) u, idx end @@ -84,10 +86,9 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function ConstantInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy + u, t, A.extrapolate, Ref(1), A ) end end @@ -98,7 +99,7 @@ end function invert_integral(A::ConstantInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return ConstantInterpolationIntInv(A.t, A.I, A) + return ConstantInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( diff --git a/src/integrals.jl b/src/integrals.jl index 3040189f..03ea26c5 100644 --- a/src/integrals.jl +++ b/src/integrals.jl @@ -12,14 +12,24 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number) # the index less than t2 idx2 = get_idx(A.t, t2, 0; idx_shift = -1, side = :first) - total = A.I[idx2] - A.I[idx1] - return if t1 == t2 - zero(total) + if A.cache_parameters + total = A.I[idx2] - A.I[idx1] + return if t1 == t2 + zero(total) + else + total += _integral(A, idx1, A.t[idx1]) + total -= _integral(A, idx1, t1) + total += _integral(A, idx2, t2) + total -= _integral(A, idx2, A.t[idx2]) + total + end else - total += _integral(A, idx1, A.t[idx1]) - total -= _integral(A, idx1, t1) - total += _integral(A, idx2, t2) - total -= _integral(A, idx2, A.t[idx2]) + total = zero(eltype(A.u)) + for idx in idx1:idx2 + lt1 = idx == idx1 ? t1 : A.t[idx] + lt2 = idx == idx2 ? t2 : A.t[idx + 1] + total += _integral(A, idx, lt2) - _integral(A, idx, lt1) + end total end end @@ -28,7 +38,8 @@ function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt = t - A.t[idx] - Δt * (A.u[idx] + A.p.slope[idx] * Δt / 2) + slope = get_parameters(A, idx) + Δt * (A.u[idx] + slope * Δt / 2) end function _integral( @@ -52,24 +63,27 @@ function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}}, t₂ = A.t[idx + 2] t_sq = (t^2) / 3 - Iu₀ = A.p.l₀[idx] * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) - Iu₁ = A.p.l₁[idx] * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) - Iu₂ = A.p.l₂[idx] * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) + l₀, l₁, l₂ = get_parameters(A, idx) + Iu₀ = l₀ * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) + Iu₁ = l₁ * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) + Iu₂ = l₂ * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) return Iu₀ + Iu₁ + Iu₂ end function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt + σ = get_parameters(A, idx) + return A.z[idx] * Δt^2 / 2 + σ * Δt^3 / 3 + Cᵢ * Δt end function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt₁sq = (t - A.t[idx])^2 / 2 Δt₂sq = (A.t[idx + 1] - t)^2 / 2 II = (-A.z[idx] * Δt₂sq^2 + A.z[idx + 1] * Δt₁sq^2) / (6A.h[idx + 1]) - IC = A.p.c₁[idx] * Δt₁sq - ID = -A.p.c₂[idx] * Δt₂sq + c₁, c₂ = get_parameters(A, idx) + IC = c₁ * Δt₁sq + ID = -c₂ * Δt₂sq II + IC + ID end @@ -91,8 +105,9 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + Δt₀ * A.du[idx] / 2) - p = A.p.c₁[idx] + Δt₁ * A.p.c₂[idx] - dp = A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + p = c₁ + Δt₁ * c₂ + dp = c₂ out += Δt₀^3 / 3 * (p - dp * Δt₀ / 4) out end @@ -103,9 +118,10 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + A.du[idx] * Δt₀ / 2 + A.ddu[idx] * Δt₀^2 / 6) - p = A.p.c₁[idx] + A.p.c₂[idx] * Δt₁ + A.p.c₃[idx] * Δt₁^2 - dp = A.p.c₂[idx] + 2A.p.c₃[idx] * Δt₁ - ddp = 2A.p.c₃[idx] + c₁, c₂, c₃ = get_parameters(A, idx) + p = c₁ + c₂ * Δt₁ + c₃ * Δt₁^2 + dp = c₂ + 2c₃ * Δt₁ + ddp = 2c₃ out += Δt₀^4 / 4 * (p - Δt₀ / 5 * dp + Δt₀^2 / 30 * ddp) out end diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index cb120a4a..7f6991b8 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -1,5 +1,5 @@ """ - LinearInterpolation(u, t; extrapolate = false, safetycopy = true) + LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using a linear polynomial. For any point, two data points one each side are chosen and connected with a line. Extrapolation extends the last linear polynomial on each side. @@ -12,7 +12,7 @@ Extrapolation extends the last linear polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -21,23 +21,23 @@ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolati p::LinearParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LinearInterpolation(u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.slope), eltype(u)}( - u, t, I, p, extrapolate, Ref(1), safetycopy) + u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function LinearInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = LinearParameterCache(u, t) - A = LinearInterpolation(u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - LinearInterpolation(u, t, I, p, extrapolate, safetycopy) +function LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = LinearParameterCache(u, t, cache_parameters) + A = LinearInterpolation(u, t, nothing, p, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) end """ - QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false, safetycopy = true) + QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using quadratic polynomials. For any point, three data points nearby are taken to fit a quadratic polynomial. Extrapolation extends the last quadratic polynomial on each side. @@ -51,7 +51,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -61,25 +61,25 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpol mode::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) mode ∈ (:Forward, :Backward) || error("mode should be :Forward or :Backward for QuadraticInterpolation") new{typeof(u), typeof(t), typeof(I), typeof(p.l₀), eltype(u)}( - u, t, I, p, mode, extrapolate, Ref(1), safetycopy) + u, t, I, p, mode, extrapolate, Ref(1), cache_parameters) end end -function QuadraticInterpolation(u, t, mode; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = QuadraticParameterCache(u, t) - A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) +function QuadraticInterpolation(u, t, mode; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = QuadraticParameterCache(u, t, cache_parameters) + A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) end -function QuadraticInterpolation(u, t; extrapolate = false, safetycopy = true) - QuadraticInterpolation(u, t, :Forward; extrapolate, safetycopy) +function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = false) + QuadraticInterpolation(u, t, :Forward; extrapolate, cache_parameters) end """ @@ -96,7 +96,6 @@ It is the method of interpolation using Lagrange polynomials of (k-1)th order pa ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: AbstractInterpolation{T} @@ -107,8 +106,7 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: idxs::Vector{Int} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + function LagrangeInterpolation(u, t, n, extrapolate) bcache = zeros(eltype(u[1]), n + 1) idxs = zeros(Int, n + 1) fill!(bcache, NaN) @@ -118,23 +116,22 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: bcache, idxs, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function LagrangeInterpolation( - u, t, n = length(t) - 1; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, n = length(t) - 1; extrapolate = false) + u, t = munge_data(u, t) if n != length(t) - 1 error("Currently only n=length(t) - 1 is supported") end - LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + LagrangeInterpolation(u, t, n, extrapolate) end """ - AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) + AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: [https://en.wikipedia.org/wiki/Akima_spline](https://en.wikipedia.org/wiki/Akima_spline). Extrapolation extends the last cubic polynomial on each side. @@ -147,7 +144,7 @@ Extrapolation extends the last cubic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: AbstractInterpolation{T} @@ -159,8 +156,8 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d::dType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + cache_parameters::Bool + function AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(b), typeof(c), typeof(d), eltype(u)}(u, t, @@ -170,13 +167,13 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end -function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) +function AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) n = length(t) dt = diff(t) m = Array{eltype(u)}(undef, n + 3) @@ -197,13 +194,13 @@ function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 - A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, safetycopy) - I = cumulative_integral(A) - AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) end """ - ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) + ConstantInterpolation(u, t; dir = :left, extrapolate = false, cache_parameters = false) It is the method of interpolating using a constant polynomial. For any point, two adjacent data points are found on either side (left and right). The value at that point depends on `dir`. If it is `:left`, then the value at the left point is chosen and if it is `:right`, the value at the right point is chosen. @@ -218,7 +215,7 @@ Extrapolation extends the last constant polynomial at the end points on each sid - `dir`: indicates which value should be used for interpolation (`:left` or `:right`). - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} u::uType @@ -228,22 +225,23 @@ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} dir::Symbol # indicates if value to the $dir should be used for the interpolation extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) + cache_parameters::Bool + function ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), eltype(u)}( - u, t, I, nothing, dir, extrapolate, Ref(1), safetycopy) + u, t, I, nothing, dir, extrapolate, Ref(1), cache_parameters) end end -function ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - A = ConstantInterpolation(u, t, nothing, dir, extrapolate, safetycopy) - I = cumulative_integral(A) - ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) +function ConstantInterpolation( + u, t; dir = :left, extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + A = ConstantInterpolation(u, t, nothing, dir, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) end """ - QuadraticSpline(u, t; extrapolate = false, safetycopy = true) + QuadraticSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise quadratic polynomials between each pair of data points. Its first derivative is also continuous. Extrapolation extends the last quadratic polynomial on each side. @@ -256,7 +254,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: AbstractInterpolation{T} @@ -269,8 +267,8 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA), typeof(d), typeof(z), eltype(u)}(u, t, @@ -281,15 +279,15 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function QuadraticSpline( u::uType, t; extrapolate = false, - safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + cache_parameters = false) where {uType <: AbstractVector{<:Number}} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -301,15 +299,17 @@ function QuadraticSpline( d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) z = tA \ d - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + + p = QuadraticSplineParameterCache(z, t, cache_parameters) + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) end function QuadraticSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -322,14 +322,15 @@ function QuadraticSpline( d = transpose(reshape(reduce(hcat, d_), :, s)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + + p = QuadraticSplineParameterCache(z, t, cache_parameters) + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) end """ - CubicSpline(u, t; extrapolate = false, safetycopy = true) + CubicSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise cubic polynomials between each pair of data points. Its first and second derivative is also continuous. Second derivative on both ends are zero, which are also called "natural" boundary conditions. Extrapolation extends the last cubic polynomial on each side. @@ -342,7 +343,7 @@ Second derivative on both ends are zero, which are also called "natural" boundar ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInterpolation{T} u::uType @@ -353,8 +354,8 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicSpline(u, t, I, p, h, z, extrapolate, safetycopy) + cache_parameters::Bool + function CubicSpline(u, t, I, p, h, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.c₁), typeof(h), typeof(z), eltype(u)}( u, t, @@ -364,15 +365,16 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function CubicSpline(u::uType, t; - extrapolate = false, safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector{<:Number}} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -389,15 +391,17 @@ function CubicSpline(u::uType, 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) z = tA \ d - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + + p = CubicSplineParameterCache(u, h, z, cache_parameters) + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) end function CubicSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -411,10 +415,11 @@ function CubicSpline( d = transpose(reshape(reduce(hcat, d_), :, n + 1)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + + p = CubicSplineParameterCache(u, h, z, cache_parameters) + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) end """ @@ -434,7 +439,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -449,7 +453,6 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineInterpolation(u, t, d, @@ -459,8 +462,7 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy) + extrapolate) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, d, @@ -471,15 +473,14 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function BSplineInterpolation( - u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) n < d + 1 && error("BSplineInterpolation needs at least d + 1, i.e. $(d+1) points.") s = zero(eltype(u)) @@ -543,11 +544,11 @@ function BSplineInterpolation( c = vec(N \ u[:, :]) N = zeros(eltype(t), n) BSplineInterpolation( - u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false, safetycopy = true) + BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false) It is a regression based B-spline. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h < length(t)` which is the number of control points to use, with smaller `h` indicating more smoothing. For more information, refer [http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf](http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf). @@ -565,7 +566,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -581,7 +581,6 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineApprox(u, t, d, @@ -592,8 +591,7 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy + extrapolate ) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, @@ -606,15 +604,14 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy::Bool + Ref(1) ) end end function BSplineApprox( - u, t, d, h, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, h, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) h < d + 1 && error("BSplineApprox needs at least d + 1, i.e. $(d+1) control points.") s = zero(eltype(u)) @@ -698,11 +695,12 @@ function BSplineApprox( P = M \ Q c[2:(end - 1)] .= vec(P) N = zeros(eltype(t), h) - BSplineApprox(u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + BSplineApprox( + u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) + CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomial such that the value and the first derivative are equal to given values in the data points. @@ -715,7 +713,7 @@ It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomi ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInterpolation{T} du::duType @@ -725,20 +723,20 @@ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInte p::CubicHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(p.c₁), eltype(u)}( - du, u, t, I, p, extrapolate, Ref(1), safetycopy) + du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) +function CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du) "Length of `u` is not equal to length of `du`." - u, t = munge_data(u, t, safetycopy) - p = CubicHermiteParameterCache(du, u, t) - A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = CubicHermiteParameterCache(du, u, t, cache_parameters) + A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) end """ @@ -756,12 +754,12 @@ section 3.4 for more details. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ -function PCHIPInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) +function PCHIPInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) du = du_PCHIP(u, t) - CubicHermiteSpline(du, u, t; extrapolate, safetycopy) + CubicHermiteSpline(du, u, t; extrapolate, cache_parameters) end """ @@ -779,7 +777,7 @@ It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polyno ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: AbstractInterpolation{T} @@ -791,19 +789,19 @@ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: p::QuinticHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(ddu), typeof(p.c₁), eltype(u)}( - ddu, du, u, t, I, p, extrapolate, Ref(1), safetycopy) + ddu, du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = true) +function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du)==length(ddu) "Length of `u` is not equal to length of `du` or `ddu`." - u, t = munge_data(u, t, safetycopy) - p = QuinticHermiteParameterCache(ddu, du, u, t) - A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters) + A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) end diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 5d09ceff..c409d5fc 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -10,15 +10,15 @@ end function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess) if isnan(t) # For correct derivative with NaN - idx = firstindex(A.u) - 1 + idx = firstindex(A.u) t1 = t2 = one(eltype(A.t)) u1 = u2 = one(eltype(A.u)) - slope = t * one(eltype(A.p.slope)) + slope = t * get_parameters(A, idx) else idx = get_idx(A.t, t, iguess) t1, t2 = A.t[idx], A.t[idx + 1] u1, u2 = A.u[idx], A.u[idx + 1] - slope = A.p.slope[idx] + slope = get_parameters(A, idx) end Δt = t - t1 @@ -38,7 +38,8 @@ end function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] - return A.u[:, idx] + A.p.slope[idx] * Δt, idx + slope = get_parameters(A, idx) + return A.u[:, idx] + slope * Δt, idx end # Quadratic Interpolation @@ -50,9 +51,10 @@ end function _interpolate(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - u₀ = A.p.l₀[i₀] * (t - A.t[i₁]) * (t - A.t[i₂]) - u₁ = A.p.l₁[i₀] * (t - A.t[i₀]) * (t - A.t[i₂]) - u₂ = A.p.l₂[i₀] * (t - A.t[i₀]) * (t - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + u₀ = l₀ * (t - A.t[i₁]) * (t - A.t[i₂]) + u₁ = l₁ * (t - A.t[i₀]) * (t - A.t[i₂]) + u₂ = l₂ * (t - A.t[i₀]) * (t - A.t[i₁]) return u₀ + u₁ + u₂, i₀ end @@ -149,7 +151,8 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx + σ = get_parameters(A, idx) + return A.z[idx] * Δt + σ * Δt^2 + Cᵢ, idx end # CubicSpline Interpolation @@ -158,8 +161,9 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t I = (A.z[idx] * Δt₂^3 + A.z[idx + 1] * Δt₁^3) / (6A.h[idx + 1]) - C = A.p.c₁[idx] * Δt₁ - D = A.p.c₂[idx] * Δt₂ + c₁, c₂ = get_parameters(A, idx) + C = c₁ * Δt₁ + D = c₂ * Δt₂ I + C + D, idx end @@ -205,7 +209,8 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * A.du[idx] - out += Δt₀^2 * (A.p.c₁[idx] + Δt₁ * A.p.c₂[idx]) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀^2 * (c₁ + Δt₁ * c₂) out, idx end @@ -216,6 +221,7 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * (A.du[idx] + A.ddu[idx] * Δt₀ / 2) - out += Δt₀^3 * (A.p.c₁[idx] + Δt₁ * (A.p.c₂[idx] + A.p.c₃[idx] * Δt₁)) + c₁, c₂, c₃ = get_parameters(A, idx) + out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁)) out, idx end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index db4393b2..2a2392bd 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -60,15 +60,11 @@ function spline_coefficients!(N, d, k, u::AbstractVector) end # helper function for data manipulation -function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}, safetycopy::Bool) - if safetycopy - u = copy(u) - t = copy(t) - end - return readonly_wrap(u), readonly_wrap(t) +function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}) + return u, t end -function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) +function munge_data(u::AbstractVector, t::AbstractVector) Tu = Base.nonmissingtype(eltype(u)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == length(u) @@ -77,17 +73,13 @@ function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) if !ismissing(u[i]) && !ismissing(t[i]) ) - if safetycopy - u = Tu.([u[i] for i in non_missing_indices]) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + u = Tu.([u[i] for i in non_missing_indices]) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(u), readonly_wrap(t) + return u, t end -function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) +function munge_data(U::StridedMatrix, t::AbstractVector) TU = Base.nonmissingtype(eltype(U)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == size(U, 2) @@ -96,20 +88,12 @@ function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) if !any(ismissing, U[:, i]) && !ismissing(t[i]) ) - if safetycopy - U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(U), readonly_wrap(t) + return U, t end -# Don't nest ReadOnlyArrays -readonly_wrap(a::AbstractArray) = ReadOnlyArray(a) -readonly_wrap(a::ReadOnlyArray) = a - function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = :last) ub = length(tvec) + ub_shift return if side == :last @@ -121,14 +105,63 @@ function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = : end end -function cumulative_integral(A) - if !hasmethod(_integral, Tuple{typeof(A), Number, Number}) - return nothing +function cumulative_integral(A, cache_parameters) + if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number}) + integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) + for idx in 1:(length(A.t) - 1)] + pushfirst!(integral_values, zero(first(integral_values))) + cumsum(integral_values) + else + promote_type(eltype(A.u), eltype(A.t))[] + end +end + +function get_parameters(A::LinearInterpolation, idx) + if A.cache_parameters + A.p.slope[idx] + else + linear_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticInterpolation, idx) + if A.cache_parameters + A.p.l₀[idx], A.p.l₁[idx], A.p.l₂[idx] + else + quadratic_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticSpline, idx) + if A.cache_parameters + A.p.σ[idx] + else + quadratic_spline_parameters(A.z, A.t, idx) + end +end + +function get_parameters(A::CubicSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_spline_parameters(A.u, A.h, A.z, idx) + end +end + +function get_parameters(A::CubicHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_hermite_spline_parameters(A.du, A.u, A.t, idx) + end +end + +function get_parameters(A::QuinticHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx], A.p.c₃[idx] + else + quintic_hermite_spline_parameters(A.ddu, A.du, A.u, A.t, idx) end - integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) - for idx in 1:(length(A.t) - 1)] - pushfirst!(integral_values, zero(first(integral_values))) - return cumsum(integral_values) end function du_PCHIP(u, t) diff --git a/src/online.jl b/src/online.jl index 630d5cf0..5193e6b2 100644 --- a/src/online.jl +++ b/src/online.jl @@ -9,69 +9,81 @@ function add_integral_values!(A) end function push!(A::LinearInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) - push!(A.p.slope, slope) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) + push!(A.p.slope, slope) + add_integral_values!(A) + end A end function push!(A::QuadraticInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) - push!(A.p.l₀, l₀) - push!(A.p.l₁, l₁) - push!(A.p.l₂, l₂) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) + push!(A.p.l₀, l₀) + push!(A.p.l₁, l₁) + push!(A.p.l₂, l₂) + add_integral_values!(A) + end A end function push!(A::ConstantInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::LinearInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::LinearInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - slope = linear_interpolation_parameters.( - Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) - append!(A.p.slope, slope) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters.( + Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) + append!(A.p.slope, slope) + add_integral_values!(A) + end A end function append!( - A::ConstantInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - add_integral_values!(A) + A::ConstantInterpolation{U, T}, u::U, t::T) where { + U, T} + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::QuadraticInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::QuadraticInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - parameters = quadratic_interpolation_parameters.( - Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters)))) - append!(A.p.l₀, l₀) - append!(A.p.l₁, l₁) - append!(A.p.l₂, l₂) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) + append!(A.p.l₀, l₀) + append!(A.p.l₁, l₁) + append!(A.p.l₂, l₂) + add_integral_values!(A) + end A end diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index a8b4b8d0..0701b3a2 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -2,9 +2,15 @@ struct LinearParameterCache{pType} slope::pType end -function LinearParameterCache(u, t) - slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) - return LinearParameterCache(slope) +function LinearParameterCache(u, t, cache_parameters) + if cache_parameters + slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) + LinearParameterCache(slope) + else + # Compute parameters once to infer types + slope = linear_interpolation_parameters(u, t, 1) + LinearParameterCache(typeof(slope)[]) + end end # Prevent e.g. Inf - Inf = NaN @@ -30,11 +36,18 @@ struct QuadraticParameterCache{pType} l₂::pType end -function QuadraticParameterCache(u, t) - parameters = quadratic_interpolation_parameters.( - Ref(u), Ref(t), 1:(length(t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters)))) - return QuadraticParameterCache(l₀, l₁, l₂) +function QuadraticParameterCache(u, t, cache_parameters) + if cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(u), Ref(t), 1:(length(t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters)))) + QuadraticParameterCache(l₀, l₁, l₂) + else + # Compute parameters once to infer types + l₀, l₁, l₂ = quadratic_interpolation_parameters(u, t, 1) + pType = typeof(l₀) + QuadraticParameterCache(pType[], pType[], pType[]) + end end function quadratic_interpolation_parameters(u, t, idx) @@ -63,9 +76,15 @@ struct QuadraticSplineParameterCache{pType} σ::pType end -function QuadraticSplineParameterCache(z, t) - σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) - return QuadraticSplineParameterCache(σ) +function QuadraticSplineParameterCache(z, t, cache_parameters) + if cache_parameters + σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) + QuadraticSplineParameterCache(σ) + else + # Compute parameters once to infer types + σ = quadratic_spline_parameters(z, t, 1) + QuadraticSplineParameterCache(typeof(σ)[]) + end end function quadratic_spline_parameters(z, t, idx) @@ -78,11 +97,18 @@ struct CubicSplineParameterCache{pType} c₂::pType end -function CubicSplineParameterCache(u, h, z) - parameters = cubic_spline_parameters.( - Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) - c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) - return CubicSplineParameterCache(c₁, c₂) +function CubicSplineParameterCache(u, h, z, cache_parameters) + if cache_parameters + parameters = cubic_spline_parameters.( + Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) + CubicSplineParameterCache(c₁, c₂) + else + # Compute parameters once to infer types + c₁, c₂ = cubic_spline_parameters(u, h, z, 1) + pType = typeof(c₁) + CubicSplineParameterCache(pType[], pType[]) + end end function cubic_spline_parameters(u, h, z, idx) @@ -96,11 +122,18 @@ struct CubicHermiteParameterCache{pType} c₂::pType end -function CubicHermiteParameterCache(du, u, t) - parameters = cubic_hermite_spline_parameters.( - Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) - return CubicHermiteParameterCache(c₁, c₂) +function CubicHermiteParameterCache(du, u, t, cache_parameters) + if cache_parameters + parameters = cubic_hermite_spline_parameters.( + Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) + CubicHermiteParameterCache(c₁, c₂) + else + # Compute parameters once to infer types + c₁, c₂ = cubic_hermite_spline_parameters(du, u, t, 1) + pType = typeof(c₁) + CubicHermiteParameterCache(pType[], pType[]) + end end function cubic_hermite_spline_parameters(du, u, t, idx) @@ -120,11 +153,18 @@ struct QuinticHermiteParameterCache{pType} c₃::pType end -function QuinticHermiteParameterCache(ddu, du, u, t) - parameters = quintic_hermite_spline_parameters.( - Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂, c₃ = collect.(eachrow(stack(collect.(parameters)))) - return QuinticHermiteParameterCache(c₁, c₂, c₃) +function QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters) + if cache_parameters + parameters = quintic_hermite_spline_parameters.( + Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) + c₁, c₂, c₃ = collect.(eachrow(stack(collect.(parameters)))) + QuinticHermiteParameterCache(c₁, c₂, c₃) + else + # Compute parameters once to infer types + c₁, c₂, c₃ = quintic_hermite_spline_parameters(ddu, du, u, t, 1) + pType = typeof(c₁) + QuinticHermiteParameterCache(pType[], pType[], pType[]) + end end function quintic_hermite_spline_parameters(ddu, du, u, t, idx) diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index f611e64c..53fff25e 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -9,7 +9,6 @@ function test_interpolation_type(T) @test hasfield(T, :t) @test hasfield(T, :extrapolate) @test hasfield(T, :idx_prev) - @test hasfield(T, :safetycopy) @test !isempty(methods(DataInterpolations._interpolate, (T, Any, Number))) @test !isempty(methods(DataInterpolations._integral, (T, Any, Number))) @test !isempty(methods(DataInterpolations._derivative, (T, Any, Number))) diff --git a/test/online_tests.jl b/test/online_tests.jl index 160dd06f..1872e0cc 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -11,9 +11,9 @@ ts_push = 1.0:0.5:4.0 @testset "$method" for method in [ LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) append!(func1, u2, t2) - func2 = method(vcat(u1, u2), vcat(t1, t2)) + func2 = method(vcat(u1, u2), vcat(t1, t2); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) @@ -22,9 +22,9 @@ ts_push = 1.0:0.5:4.0 @test func1(ts_append) == func2(ts_append) @test func1.I == func2.I - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) push!(func1, 1.0, 4.0) - func2 = method(vcat(u1, 1.0), vcat(t1, 4.0)) + func2 = method(vcat(u1, 1.0), vcat(t1, 4.0); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) diff --git a/test/parameter_tests.jl b/test/parameter_tests.jl index bcd26cf7..2e84b98d 100644 --- a/test/parameter_tests.jl +++ b/test/parameter_tests.jl @@ -3,14 +3,14 @@ using DataInterpolations @testset "Linear Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = LinearInterpolation(u, t) + A = LinearInterpolation(u, t; cache_parameters = true) @test A.p.slope ≈ [4.0, -2.0, 1.0, 0.0] end @testset "Quadratic Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticInterpolation(u, t) + A = QuadraticInterpolation(u, t; cache_parameters = true) @test A.p.l₀ ≈ [0.5, 2.5, 1.5] @test A.p.l₁ ≈ [-5.0, -3.0, -4.0] @test A.p.l₂ ≈ [1.5, 2.0, 2.0] @@ -19,14 +19,14 @@ end @testset "Quadratic Spline" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticSpline(u, t) + A = QuadraticSpline(u, t; cache_parameters = true) @test A.p.σ ≈ [4.0, -10.0, 13.0, -14.0] end @testset "Cubic Spline" begin u = [1, 5, 3, 4, 4] t = collect(1:5) - A = CubicSpline(u, t) + A = CubicSpline(u, t; cache_parameters = true) @test A.p.c₁ ≈ [6.839285714285714, 1.642857142857143, 4.589285714285714, 4.0] @test A.p.c₂ ≈ [1.0, 6.839285714285714, 1.642857142857143, 4.589285714285714] end @@ -35,7 +35,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = CubicHermiteSpline(du, u, t) + A = CubicHermiteSpline(du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -5.0, -5.0, -8.0] @test A.p.c₂ ≈ [0.0, 13.0, 12.0, 9.0] end @@ -45,7 +45,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuinticHermiteSpline(ddu, du, u, t) + A = QuinticHermiteSpline(ddu, du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -6.5, -8.0, -10.0] @test A.p.c₂ ≈ [1.0, 19.5, 20.0, 19.0] @test A.p.c₃ ≈ [1.5, -37.5, -37.0, -26.5] diff --git a/test/runtests.jl b/test/runtests.jl index 0c722b2d..80080a75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,3 +10,4 @@ using SafeTestsets @safetestset "Online Tests" include("online_tests.jl") @safetestset "Regularization Smoothing" include("regularization.jl") @safetestset "Show methods" include("show.jl") +@safetestset "Zygote support" include("zygote_tests.jl") diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl new file mode 100644 index 00000000..1a7fc447 --- /dev/null +++ b/test/zygote_tests.jl @@ -0,0 +1,101 @@ +using DataInterpolations +using ForwardDiff +using Zygote + +function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name::String) + func = method(args..., u, t, args_after...; kwargs..., extrapolate = true) + (; u, t) = func + trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) + trange_exclude = filter(x -> !in(x, t), trange) + @testset "$name, derivatives w.r.t. input" begin + for _t in trange_exclude + adiff = DataInterpolations.derivative(func, _t) + zdiff = only(Zygote.gradient(func, _t)) + isnothing(zdiff) && (zdiff = 0.0) + @test adiff ≈ zdiff + end + end + if method ∉ [LagrangeInterpolation, BSplineInterpolation, BSplineApprox] + @testset "$name, derivatives w.r.t. u" begin + function f(u) + A = method(args..., u, t, args_after...; kwargs..., extrapolate = true) + out = zero(eltype(u)) + for _t in trange + out += A(_t) + end + out + end + zgrad = only(Zygote.gradient(f, u)) + fgrad = ForwardDiff.gradient(f, u) + @test zgrad ≈ fgrad + end + end +end + +@testset "LinearInterpolation" begin + u = vcat(collect(1.0:5.0), 2 * collect(6.0:10.0)) + t = collect(1.0:10.0) + test_zygote( + LinearInterpolation, u, t; name = "Linear Interpolation") +end + +@testset "Quadratic Interpolation" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticInterpolation, u, t; name = "Quadratic Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation") +end + +@testset "Cubic Hermite Spline" begin + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote(CubicHermiteSpline, u, t, args = [du], name = "Cubic Hermite Spline") +end + +@testset "Quintic Hermite Spline" begin + ddu = [0.0, -0.00033, 0.0051, -0.0067, 0.0029, 0.0] + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote( + QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline") +end + +@testset "Quadratic Spline" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticSpline, u, t, name = "Quadratic Spline") +end + +@testset "Lagrange Interpolation" begin + u = [1.0, 4.0, 9.0] + t = [1.0, 2.0, 3.0] + test_zygote(LagrangeInterpolation, u, t, name = "Lagrange Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation") +end + +@testset "Cubic Spline" begin + u = [0.0, 1.0, 3.0] + t = [-1.0, 0.0, 1.0] + test_zygote(CubicSpline, u, t, name = "Cubic Spline") +end + +@testset "BSplines" begin + t = [0, 62.25, 109.66, 162.66, 205.8, 252.3] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + test_zygote(BSplineInterpolation, u, t; args_after = [2, :Uniform, :Uniform], + name = "BSpline Interpolation") + test_zygote(BSplineApprox, u, t; args_after = [2, 4, :Uniform, :Uniform], + name = "BSpline approximation") +end