From c8b322455377bcbfe6b08ea14373979def5d00f3 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 17:17:57 +0200 Subject: [PATCH 1/9] Refactor parameter caching, add zygote tests --- Project.toml | 5 +- docs/src/interface.md | 17 +- ext/DataInterpolationsOptimExt.jl | 5 +- ...ataInterpolationsRegularizationToolsExt.jl | 28 +- src/DataInterpolations.jl | 19 +- src/derivatives.jl | 23 +- src/integral_inverses.jl | 15 +- src/integrals.jl | 54 ++- src/interpolation_caches.jl | 335 +++++++++++------- src/interpolation_methods.jl | 30 +- src/interpolation_utils.jl | 84 +++-- src/online.jl | 96 ++--- test/interpolation_tests.jl | 1 - test/online_tests.jl | 8 +- test/parameter_tests.jl | 12 +- test/runtests.jl | 1 + test/zygote_tests.jl | 66 ++++ 17 files changed, 497 insertions(+), 302 deletions(-) create mode 100644 test/zygote_tests.jl diff --git a/Project.toml b/Project.toml index b06d2adb..e492399f 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -33,7 +32,6 @@ LinearAlgebra = "1.10" Optim = "1.6" PrettyTables = "2" QuadGK = "2.9.1" -ReadOnlyArrays = "0.2.0" RecipesBase = "1.3" Reexport = "1" RegularizationTools = "0.6" @@ -55,6 +53,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"] +test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Zygote"] diff --git a/docs/src/interface.md b/docs/src/interface.md index ca5e9819..cfc9ed0b 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -35,22 +35,7 @@ A2(300.0) The values computed beyond the range of the time points provided during interpolation will not be reliable, as these methods only perform well within the range and the first/last piece polynomial fit is extrapolated on either side which might not reflect the true nature of the data. -The keyword `safetycopy = false` can be passed to make sure no copies of `u` and `t` are made when initializing the interpolation object. - -```@example interface -A3 = QuadraticInterpolation(u, t; safetycopy = false) - -# Check for same memory -u === A3.u.parent -``` - -Note that this does not prevent allocation in every interpolation constructor call, because parameter values are cached for all interpolation types except [`ConstantInterpolation`](@ref). - -Because of the caching of parameters which depend on `u` and `t`, this data should not be mutated. Therefore `u` and `t` are wrapped in a `ReadOnlyArray` from [ReadOnlyArrays.jl](https://github.com/JuliaArrays/ReadOnlyArrays.jl). - -```@repl interface -A3.t[2] = 3.14 -``` +The keyword `cache_parameters = true` can be passed to precalculate parameters at initialization, making evalations cheaper to compute. This is not compatible with modifying `u` and `t`. The default `cache_parameters = false` does however not prevent allocation in every interpolation constructor call. ## Derivatives diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl index 5528503f..b3bce295 100644 --- a/ext/DataInterpolationsOptimExt.jl +++ b/ext/DataInterpolationsOptimExt.jl @@ -18,9 +18,8 @@ function Curvefit(u, box = false, lb = nothing, ub = nothing; - extrapolate = false, - safetycopy = false) - u, t = munge_data(u, t, safetycopy) + extrapolate = false) + u, t = munge_data(u, t) errfun(t, u, p) = sum(abs2.(u .- model(t, p))) if box == false mfit = optimize(p -> errfun(t, u, p), p0, alg) diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl index 10ea3e4c..732ea1bb 100644 --- a/ext/DataInterpolationsRegularizationToolsExt.jl +++ b/ext/DataInterpolationsRegularizationToolsExt.jl @@ -69,8 +69,8 @@ A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ = 1.0, alg = :gcv_svd) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = LA.diagm(sqrt.(wr)) @@ -86,8 +86,8 @@ A = RegularizationSmooth(u, t, d; λ = 1.0, alg = :gcv_svd, extrapolate = false) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -115,8 +115,8 @@ A = RegularizationSmooth(u, t, t̂, d; λ = 1.0, alg = :gcv_svd, extrapolate = f """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = Array{Float64}(LA.I, N, N) @@ -143,8 +143,8 @@ A = RegularizationSmooth(u, t, t̂, wls, d; λ = 1.0, alg = :gcv_svd, extrapolat """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) @@ -172,8 +172,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -202,8 +202,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -232,8 +232,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index c86a6579..19cb47c0 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -7,7 +7,6 @@ abstract type AbstractInterpolation{T} end using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff -using ReadOnlyArrays import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, bracketstrictlymontonic @@ -88,12 +87,6 @@ function Base.showerror(io::IO, e::IntegralNotInvertibleError) print(io, INTEGRAL_NOT_INVERTIBLE_ERROR) end -const MUST_COPY_ERROR = "A copy must be made of u, t to filter missing data" -struct MustCopyError <: Exception end -function Base.showerror(io::IO, e::MustCopyError) - print(io, MUST_COPY_ERROR) -end - export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, BSplineInterpolation, BSplineApprox, CubicHermiteSpline, @@ -126,12 +119,12 @@ struct RegularizationSmooth{uType, tType, T, T2, ITP <: AbstractInterpolation{T} Aitp, extrapolate) new{typeof(u), typeof(t), eltype(u), typeof(λ), typeof(Aitp)}( - readonly_wrap(u), - readonly_wrap(û), - readonly_wrap(t), - readonly_wrap(t̂), - readonly_wrap(oftype(u.parent, wls)), - readonly_wrap(oftype(u.parent, wr)), + u, + û, + t, + t̂, + wls, + wr, d, λ, alg, diff --git a/src/derivatives.jl b/src/derivatives.jl index 30c76fd0..01eb18bb 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -19,14 +19,16 @@ end function _derivative(A::LinearInterpolation, t::Number, iguess) idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -2, side = :first) - A.p.slope[idx], idx + slope = get_parameters(A, idx) + slope, idx end function _derivative(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - du₀ = A.p.l₀[i₀] * (2t - A.t[i₁] - A.t[i₂]) - du₁ = A.p.l₁[i₀] * (2t - A.t[i₀] - A.t[i₂]) - du₂ = A.p.l₂[i₀] * (2t - A.t[i₀] - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂]) + du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂]) + du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁]) return @views @. du₀ + du₁ + du₂, i₀ end @@ -129,7 +131,7 @@ end # QuadraticSpline Interpolation function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first) - σ = A.p.σ[idx - 1] + σ = get_parameters(A, idx - 1) A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx end @@ -139,8 +141,9 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t dI = (-A.z[idx] * Δt₂^2 + A.z[idx + 1] * Δt₁^2) / (2A.h[idx + 1]) - dC = A.p.c₁[idx] - dD = -A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + dC = c₁ + dD = -c₂ dI + dC + dD, idx end @@ -193,7 +196,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] - out += Δt₀ * (Δt₀ * A.p.c₂[idx] + 2(A.p.c₁[idx] + Δt₁ * A.p.c₂[idx])) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂)) out, idx end @@ -204,7 +208,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] + A.ddu[idx] * Δt₀ + c₁, c₂, c₃ = get_parameters(A, idx) out += Δt₀^2 * - (3A.p.c₁[idx] + (3Δt₁ + Δt₀) * A.p.c₂[idx] + (3Δt₁^2 + Δt₀ * 2Δt₁) * A.p.c₃[idx]) + (3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃) out, idx end diff --git a/src/integral_inverses.jl b/src/integral_inverses.jl index 4437726e..38c14b14 100644 --- a/src/integral_inverses.jl +++ b/src/integral_inverses.jl @@ -40,10 +40,9 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function LinearInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy) + u, t, A.extrapolate, Ref(1), A) end end @@ -51,9 +50,11 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) return all(A.u .> 0) end +get_I(A::AbstractInterpolation) = isnothing(A.I) ? cumulative_integral(A) : A.I + function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return LinearInterpolationIntInv(A.t, A.I, A) + return LinearInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( @@ -61,7 +62,8 @@ function _interpolate( idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] x = A.itp.u[idx] - u = A.u[idx] + 2Δt / (x + sqrt(x^2 + A.itp.p.slope[idx] * 2Δt)) + slope = get_parameters(A.itp, idx) + u = A.u[idx] + 2Δt / (x + sqrt(x^2 + slope * 2Δt)) u, idx end @@ -84,10 +86,9 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function ConstantInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy + u, t, A.extrapolate, Ref(1), A ) end end @@ -98,7 +99,7 @@ end function invert_integral(A::ConstantInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return ConstantInterpolationIntInv(A.t, A.I, A) + return ConstantInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( diff --git a/src/integrals.jl b/src/integrals.jl index 3040189f..03ea26c5 100644 --- a/src/integrals.jl +++ b/src/integrals.jl @@ -12,14 +12,24 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number) # the index less than t2 idx2 = get_idx(A.t, t2, 0; idx_shift = -1, side = :first) - total = A.I[idx2] - A.I[idx1] - return if t1 == t2 - zero(total) + if A.cache_parameters + total = A.I[idx2] - A.I[idx1] + return if t1 == t2 + zero(total) + else + total += _integral(A, idx1, A.t[idx1]) + total -= _integral(A, idx1, t1) + total += _integral(A, idx2, t2) + total -= _integral(A, idx2, A.t[idx2]) + total + end else - total += _integral(A, idx1, A.t[idx1]) - total -= _integral(A, idx1, t1) - total += _integral(A, idx2, t2) - total -= _integral(A, idx2, A.t[idx2]) + total = zero(eltype(A.u)) + for idx in idx1:idx2 + lt1 = idx == idx1 ? t1 : A.t[idx] + lt2 = idx == idx2 ? t2 : A.t[idx + 1] + total += _integral(A, idx, lt2) - _integral(A, idx, lt1) + end total end end @@ -28,7 +38,8 @@ function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt = t - A.t[idx] - Δt * (A.u[idx] + A.p.slope[idx] * Δt / 2) + slope = get_parameters(A, idx) + Δt * (A.u[idx] + slope * Δt / 2) end function _integral( @@ -52,24 +63,27 @@ function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}}, t₂ = A.t[idx + 2] t_sq = (t^2) / 3 - Iu₀ = A.p.l₀[idx] * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) - Iu₁ = A.p.l₁[idx] * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) - Iu₂ = A.p.l₂[idx] * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) + l₀, l₁, l₂ = get_parameters(A, idx) + Iu₀ = l₀ * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) + Iu₁ = l₁ * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) + Iu₂ = l₂ * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) return Iu₀ + Iu₁ + Iu₂ end function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt + σ = get_parameters(A, idx) + return A.z[idx] * Δt^2 / 2 + σ * Δt^3 / 3 + Cᵢ * Δt end function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt₁sq = (t - A.t[idx])^2 / 2 Δt₂sq = (A.t[idx + 1] - t)^2 / 2 II = (-A.z[idx] * Δt₂sq^2 + A.z[idx + 1] * Δt₁sq^2) / (6A.h[idx + 1]) - IC = A.p.c₁[idx] * Δt₁sq - ID = -A.p.c₂[idx] * Δt₂sq + c₁, c₂ = get_parameters(A, idx) + IC = c₁ * Δt₁sq + ID = -c₂ * Δt₂sq II + IC + ID end @@ -91,8 +105,9 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + Δt₀ * A.du[idx] / 2) - p = A.p.c₁[idx] + Δt₁ * A.p.c₂[idx] - dp = A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + p = c₁ + Δt₁ * c₂ + dp = c₂ out += Δt₀^3 / 3 * (p - dp * Δt₀ / 4) out end @@ -103,9 +118,10 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + A.du[idx] * Δt₀ / 2 + A.ddu[idx] * Δt₀^2 / 6) - p = A.p.c₁[idx] + A.p.c₂[idx] * Δt₁ + A.p.c₃[idx] * Δt₁^2 - dp = A.p.c₂[idx] + 2A.p.c₃[idx] * Δt₁ - ddp = 2A.p.c₃[idx] + c₁, c₂, c₃ = get_parameters(A, idx) + p = c₁ + c₂ * Δt₁ + c₃ * Δt₁^2 + dp = c₂ + 2c₃ * Δt₁ + ddp = 2c₃ out += Δt₀^4 / 4 * (p - Δt₀ / 5 * dp + Δt₀^2 / 30 * ddp) out end diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index c7274471..286bf6bc 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -1,5 +1,5 @@ """ - LinearInterpolation(u, t; extrapolate = false) + LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using a linear polynomial. For any point, two data points one each side are chosen and connected with a line. Extrapolation extends the last linear polynomial on each side. @@ -12,7 +12,7 @@ Extrapolation extends the last linear polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -21,23 +21,33 @@ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolati p::LinearParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LinearInterpolation(u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.slope), eltype(u)}( - u, t, I, p, extrapolate, Ref(1), safetycopy) + u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function LinearInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = LinearParameterCache(u, t) - A = LinearInterpolation(u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - LinearInterpolation(u, t, I, p, extrapolate, safetycopy) +function LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = if cache_parameters + LinearParameterCache(u, t) + else + LinearParameterCache(nothing) + end + + A = LinearInterpolation(u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) + end + + A end """ - QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false) + QuadraticInterpolation(u, t, mode = :Forward; cache_parameters = false) It is the method of interpolating between the data points using quadratic polynomials. For any point, three data points nearby are taken to fit a quadratic polynomial. Extrapolation extends the last quadratic polynomial on each side. @@ -51,7 +61,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -61,25 +71,35 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpol mode::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) mode ∈ (:Forward, :Backward) || error("mode should be :Forward or :Backward for QuadraticInterpolation") new{typeof(u), typeof(t), typeof(I), typeof(p.l₀), eltype(u)}( - u, t, I, p, mode, extrapolate, Ref(1), safetycopy) + u, t, I, p, mode, extrapolate, Ref(1), cache_parameters) end end -function QuadraticInterpolation(u, t, mode; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = QuadraticParameterCache(u, t) - A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) +function QuadraticInterpolation(u, t, mode; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = if cache_parameters + QuadraticParameterCache(u, t) + else + QuadraticParameterCache(nothing, nothing, nothing) + end + + A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) + end + + A end -function QuadraticInterpolation(u, t; extrapolate = false, safetycopy = true) - QuadraticInterpolation(u, t, :Forward; extrapolate, safetycopy) +function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = false) + QuadraticInterpolation(u, t, :Forward; extrapolate, cache_parameters) end """ @@ -96,7 +116,6 @@ It is the method of interpolation using Lagrange polynomials of (k-1)th order pa ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: AbstractInterpolation{T} @@ -107,8 +126,7 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: idxs::Vector{Int} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + function LagrangeInterpolation(u, t, n, extrapolate) bcache = zeros(eltype(u[1]), n + 1) idxs = zeros(Int, n + 1) fill!(bcache, NaN) @@ -118,23 +136,22 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: bcache, idxs, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function LagrangeInterpolation( - u, t, n = length(t) - 1; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, n = length(t) - 1; extrapolate = false) + u, t = munge_data(u, t) if n != length(t) - 1 error("Currently only n=length(t) - 1 is supported") end - LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + LagrangeInterpolation(u, t, n, extrapolate) end """ - AkimaInterpolation(u, t; extrapolate = false) + AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: https://en.wikipedia.org/wiki/Akima_spline. Extrapolation extends the last cubic polynomial on each side. @@ -147,7 +164,7 @@ Extrapolation extends the last cubic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: AbstractInterpolation{T} @@ -159,8 +176,8 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d::dType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + cache_parameters::Bool + function AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(b), typeof(c), typeof(d), eltype(u)}(u, t, @@ -170,13 +187,13 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end -function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) +function AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) n = length(t) dt = diff(t) m = Array{eltype(u)}(undef, n + 3) @@ -197,13 +214,18 @@ function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 - A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, safetycopy) - I = cumulative_integral(A) - AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) + end + + A end """ - ConstantInterpolation(u, t; dir = :left, extrapolate = false) + ConstantInterpolation(u, t; dir = :left, extrapolate = false, cache_parameters = false) It is the method of interpolating using a constant polynomial. For any point, two adjacent data points are found on either side (left and right). The value at that point depends on `dir`. If it is `:left`, then the value at the left point is chosen and if it is `:right`, the value at the right point is chosen. @@ -218,7 +240,7 @@ Extrapolation extends the last constant polynomial at the end points on each sid - `dir`: indicates which value should be used for interpolation (`:left` or `:right`). - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} u::uType @@ -228,22 +250,28 @@ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} dir::Symbol # indicates if value to the $dir should be used for the interpolation extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) + cache_parameters::Bool + function ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), eltype(u)}( - u, t, I, nothing, dir, extrapolate, Ref(1), safetycopy) + u, t, I, nothing, dir, extrapolate, Ref(1), cache_parameters) end end -function ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - A = ConstantInterpolation(u, t, nothing, dir, extrapolate, safetycopy) - I = cumulative_integral(A) - ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) +function ConstantInterpolation( + u, t; dir = :left, extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + A = ConstantInterpolation(u, t, nothing, dir, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) + end + + A end """ - QuadraticSpline(u, t; extrapolate = false) + QuadraticSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise quadratic polynomials between each pair of data points. Its first derivative is also continuous. Extrapolation extends the last quadratic polynomial on each side. @@ -256,7 +284,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: AbstractInterpolation{T} @@ -269,8 +297,8 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA), typeof(d), typeof(z), eltype(u)}(u, t, @@ -281,15 +309,15 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function QuadraticSpline( u::uType, t; extrapolate = false, - safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + cache_parameters = false) where {uType <: AbstractVector{<:Number}} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -301,15 +329,27 @@ function QuadraticSpline( d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) z = tA \ d - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + + p = if cache_parameters + QuadraticSplineParameterCache(z, t) + else + QuadraticSplineParameterCache(nothing) + end + + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) + end + + A end function QuadraticSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -322,14 +362,23 @@ function QuadraticSpline( d = transpose(reshape(reduce(hcat, d_), :, s)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + p = if cache_parameters + QuadraticSplineParameterCache(z, t) + else + QuadraticSplineParameterCache(nothing) + end + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) + end + + A end """ - CubicSpline(u, t; extrapolate = false) + CubicSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise cubic polynomials between each pair of data points. Its first and second derivative is also continuous. Second derivative on both ends are zero, which are also called "natural" boundary conditions. Extrapolation extends the last cubic polynomial on each side. @@ -342,7 +391,7 @@ Second derivative on both ends are zero, which are also called "natural" boundar ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInterpolation{T} u::uType @@ -353,8 +402,8 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicSpline(u, t, I, p, h, z, extrapolate, safetycopy) + cache_parameters::Bool + function CubicSpline(u, t, I, p, h, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.c₁), typeof(h), typeof(z), eltype(u)}( u, t, @@ -364,15 +413,16 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function CubicSpline(u::uType, t; - extrapolate = false, safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector{<:Number}} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -389,15 +439,25 @@ function CubicSpline(u::uType, 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) z = tA \ d - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + p = if cache_parameters + CubicSplineParameterCache(u, h, z) + else + CubicSplineParameterCache(nothing, nothing) + end + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + end + + A end function CubicSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -411,10 +471,20 @@ function CubicSpline( d = transpose(reshape(reduce(hcat, d_), :, n + 1)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + p = if cache_parameters + CubicSplineParameterCache(u, h, z) + else + CubicSplineParameterCache(nothing, nothing) + end + + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + end + + A end """ @@ -434,7 +504,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -449,7 +518,6 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineInterpolation(u, t, d, @@ -459,8 +527,7 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy) + extrapolate) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, d, @@ -471,15 +538,14 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function BSplineInterpolation( - u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) n < d + 1 && error("BSplineInterpolation needs at least d + 1, i.e. $(d+1) points.") s = zero(eltype(u)) @@ -543,11 +609,11 @@ function BSplineInterpolation( c = vec(N \ u[:, :]) N = zeros(eltype(t), n) BSplineInterpolation( - u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false) + BSplineApprox(u, t, d, h, pVecType, knotVecType) It is a regression based B-spline. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h < length(t)` which is the number of control points to use, with smaller `h` indicating more smoothing. For more information, refer http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf. @@ -565,7 +631,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -581,7 +646,6 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineApprox(u, t, d, @@ -592,8 +656,7 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy + extrapolate ) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, @@ -606,15 +669,14 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy::Bool + Ref(1) ) end end function BSplineApprox( - u, t, d, h, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, h, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) h < d + 1 && error("BSplineApprox needs at least d + 1, i.e. $(d+1) control points.") s = zero(eltype(u)) @@ -698,11 +760,12 @@ function BSplineApprox( P = M \ Q c[2:(end - 1)] .= vec(P) N = zeros(eltype(t), h) - BSplineApprox(u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + BSplineApprox( + u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - CubicHermiteSpline(du, u, t; extrapolate = false) + CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomial such that the value and the first derivative are equal to given values in the data points. @@ -715,7 +778,7 @@ It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomi ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInterpolation{T} du::duType @@ -725,24 +788,33 @@ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInte p::CubicHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(p.c₁), eltype(u)}( - du, u, t, I, p, extrapolate, Ref(1), safetycopy) + du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) +function CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du) "Length of `u` is not equal to length of `du`." - u, t = munge_data(u, t, safetycopy) - p = CubicHermiteParameterCache(du, u, t) - A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = if cache_parameters + CubicHermiteParameterCache(du, u, t) + else + CubicHermiteParameterCache(nothing, nothing) + end + A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) + end + + A end """ - QuinticHermiteSpline(ddu, du, u, t; extrapolate = false) + QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polynomial such that the value and the first and second derivative are equal to given values in the data points. @@ -756,7 +828,7 @@ It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polyno ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: AbstractInterpolation{T} @@ -768,19 +840,28 @@ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: p::QuinticHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(ddu), typeof(p.c₁), eltype(u)}( - ddu, du, u, t, I, p, extrapolate, Ref(1), safetycopy) + ddu, du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = true) +function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du)==length(ddu) "Length of `u` is not equal to length of `du` or `ddu`." - u, t = munge_data(u, t, safetycopy) - p = QuinticHermiteParameterCache(ddu, du, u, t) - A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = if cache_parameters + QuinticHermiteParameterCache(ddu, du, u, t) + else + QuinticHermiteParameterCache(nothing, nothing, nothing) + end + A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) + end + + A end diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 5d09ceff..c409d5fc 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -10,15 +10,15 @@ end function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess) if isnan(t) # For correct derivative with NaN - idx = firstindex(A.u) - 1 + idx = firstindex(A.u) t1 = t2 = one(eltype(A.t)) u1 = u2 = one(eltype(A.u)) - slope = t * one(eltype(A.p.slope)) + slope = t * get_parameters(A, idx) else idx = get_idx(A.t, t, iguess) t1, t2 = A.t[idx], A.t[idx + 1] u1, u2 = A.u[idx], A.u[idx + 1] - slope = A.p.slope[idx] + slope = get_parameters(A, idx) end Δt = t - t1 @@ -38,7 +38,8 @@ end function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] - return A.u[:, idx] + A.p.slope[idx] * Δt, idx + slope = get_parameters(A, idx) + return A.u[:, idx] + slope * Δt, idx end # Quadratic Interpolation @@ -50,9 +51,10 @@ end function _interpolate(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - u₀ = A.p.l₀[i₀] * (t - A.t[i₁]) * (t - A.t[i₂]) - u₁ = A.p.l₁[i₀] * (t - A.t[i₀]) * (t - A.t[i₂]) - u₂ = A.p.l₂[i₀] * (t - A.t[i₀]) * (t - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + u₀ = l₀ * (t - A.t[i₁]) * (t - A.t[i₂]) + u₁ = l₁ * (t - A.t[i₀]) * (t - A.t[i₂]) + u₂ = l₂ * (t - A.t[i₀]) * (t - A.t[i₁]) return u₀ + u₁ + u₂, i₀ end @@ -149,7 +151,8 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx + σ = get_parameters(A, idx) + return A.z[idx] * Δt + σ * Δt^2 + Cᵢ, idx end # CubicSpline Interpolation @@ -158,8 +161,9 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t I = (A.z[idx] * Δt₂^3 + A.z[idx + 1] * Δt₁^3) / (6A.h[idx + 1]) - C = A.p.c₁[idx] * Δt₁ - D = A.p.c₂[idx] * Δt₂ + c₁, c₂ = get_parameters(A, idx) + C = c₁ * Δt₁ + D = c₂ * Δt₂ I + C + D, idx end @@ -205,7 +209,8 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * A.du[idx] - out += Δt₀^2 * (A.p.c₁[idx] + Δt₁ * A.p.c₂[idx]) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀^2 * (c₁ + Δt₁ * c₂) out, idx end @@ -216,6 +221,7 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * (A.du[idx] + A.ddu[idx] * Δt₀ / 2) - out += Δt₀^3 * (A.p.c₁[idx] + Δt₁ * (A.p.c₂[idx] + A.p.c₃[idx] * Δt₁)) + c₁, c₂, c₃ = get_parameters(A, idx) + out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁)) out, idx end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 466248b1..17b08328 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -60,15 +60,11 @@ function spline_coefficients!(N, d, k, u::AbstractVector) end # helper function for data manipulation -function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}, safetycopy::Bool) - if safetycopy - u = copy(u) - t = copy(t) - end - return readonly_wrap(u), readonly_wrap(t) +function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}) + return u, t end -function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) +function munge_data(u::AbstractVector, t::AbstractVector) Tu = Base.nonmissingtype(eltype(u)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == length(u) @@ -77,17 +73,13 @@ function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) if !ismissing(u[i]) && !ismissing(t[i]) ) - if safetycopy - u = Tu.([u[i] for i in non_missing_indices]) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + u = Tu.([u[i] for i in non_missing_indices]) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(u), readonly_wrap(t) + return u, t end -function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) +function munge_data(U::StridedMatrix, t::AbstractVector) TU = Base.nonmissingtype(eltype(U)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == size(U, 2) @@ -96,20 +88,12 @@ function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) if !any(ismissing, U[:, i]) && !ismissing(t[i]) ) - if safetycopy - U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(U), readonly_wrap(t) + return U, t end -# Don't nest ReadOnlyArrays -readonly_wrap(a::AbstractArray) = ReadOnlyArray(a) -readonly_wrap(a::ReadOnlyArray) = a - function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = :last) ub = length(tvec) + ub_shift return if side == :last @@ -130,3 +114,51 @@ function cumulative_integral(A) pushfirst!(integral_values, zero(first(integral_values))) return cumsum(integral_values) end + +function get_parameters(A::LinearInterpolation, idx) + if A.cache_parameters + A.p.slope[idx] + else + linear_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticInterpolation, idx) + if A.cache_parameters + A.p.l₀[idx], A.p.l₁[idx], A.p.l₂[idx] + else + quadratic_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticSpline, idx) + if A.cache_parameters + A.p.σ[idx] + else + quadratic_spline_parameters(A.z, A.t, idx) + end +end + +function get_parameters(A::CubicSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_spline_parameters(A.u, A.h, A.z, idx) + end +end + +function get_parameters(A::CubicHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_hermite_spline_parameters(A.du, A.u, A.t, idx) + end +end + +function get_parameters(A::QuinticHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx], A.p.c₃[idx] + else + quintic_hermite_spline_parameters(A.ddu, A.du, A.u, A.t, idx) + end +end diff --git a/src/online.jl b/src/online.jl index 0fab5d44..5193e6b2 100644 --- a/src/online.jl +++ b/src/online.jl @@ -9,69 +9,81 @@ function add_integral_values!(A) end function push!(A::LinearInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) - push!(A.p.slope, slope) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) + push!(A.p.slope, slope) + add_integral_values!(A) + end A end function push!(A::QuadraticInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) - push!(A.p.l₀, l₀) - push!(A.p.l₁, l₁) - push!(A.p.l₂, l₂) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) + push!(A.p.l₀, l₀) + push!(A.p.l₁, l₁) + push!(A.p.l₂, l₂) + add_integral_values!(A) + end A end function push!(A::ConstantInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::LinearInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::LinearInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - slope = linear_interpolation_parameters.( - Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) - append!(A.p.slope, slope) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters.( + Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) + append!(A.p.slope, slope) + add_integral_values!(A) + end A end function append!( - A::ConstantInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - add_integral_values!(A) + A::ConstantInterpolation{U, T}, u::U, t::T) where { + U, T} + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::QuadraticInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::QuadraticInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - parameters = quadratic_interpolation_parameters.( - Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) - append!(A.p.l₀, l₀) - append!(A.p.l₁, l₁) - append!(A.p.l₂, l₂) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) + append!(A.p.l₀, l₀) + append!(A.p.l₁, l₁) + append!(A.p.l₂, l₂) + add_integral_values!(A) + end A end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9549038c..9562c7b4 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -9,7 +9,6 @@ function test_interpolation_type(T) @test hasfield(T, :t) @test hasfield(T, :extrapolate) @test hasfield(T, :idx_prev) - @test hasfield(T, :safetycopy) @test !isempty(methods(DataInterpolations._interpolate, (T, Any, Number))) @test !isempty(methods(DataInterpolations._integral, (T, Any, Number))) @test !isempty(methods(DataInterpolations._derivative, (T, Any, Number))) diff --git a/test/online_tests.jl b/test/online_tests.jl index f9c3e1dd..3ae6438e 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -9,9 +9,9 @@ u2 = [1.0, 2.0, 1.0] ts = 1.0:0.5:6.0 for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) append!(func1, u2, t2) - func2 = method(vcat(u1, u2), vcat(t1, t2)) + func2 = method(vcat(u1, u2), vcat(t1, t2); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) @@ -20,9 +20,9 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio @test func1(ts) == func2(ts) @test func1.I == func2.I - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) push!(func1, 1.0, 4.0) - func2 = method(vcat(u1, 1.0), vcat(t1, 4.0)) + func2 = method(vcat(u1, 1.0), vcat(t1, 4.0); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) diff --git a/test/parameter_tests.jl b/test/parameter_tests.jl index bcd26cf7..2e84b98d 100644 --- a/test/parameter_tests.jl +++ b/test/parameter_tests.jl @@ -3,14 +3,14 @@ using DataInterpolations @testset "Linear Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = LinearInterpolation(u, t) + A = LinearInterpolation(u, t; cache_parameters = true) @test A.p.slope ≈ [4.0, -2.0, 1.0, 0.0] end @testset "Quadratic Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticInterpolation(u, t) + A = QuadraticInterpolation(u, t; cache_parameters = true) @test A.p.l₀ ≈ [0.5, 2.5, 1.5] @test A.p.l₁ ≈ [-5.0, -3.0, -4.0] @test A.p.l₂ ≈ [1.5, 2.0, 2.0] @@ -19,14 +19,14 @@ end @testset "Quadratic Spline" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticSpline(u, t) + A = QuadraticSpline(u, t; cache_parameters = true) @test A.p.σ ≈ [4.0, -10.0, 13.0, -14.0] end @testset "Cubic Spline" begin u = [1, 5, 3, 4, 4] t = collect(1:5) - A = CubicSpline(u, t) + A = CubicSpline(u, t; cache_parameters = true) @test A.p.c₁ ≈ [6.839285714285714, 1.642857142857143, 4.589285714285714, 4.0] @test A.p.c₂ ≈ [1.0, 6.839285714285714, 1.642857142857143, 4.589285714285714] end @@ -35,7 +35,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = CubicHermiteSpline(du, u, t) + A = CubicHermiteSpline(du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -5.0, -5.0, -8.0] @test A.p.c₂ ≈ [0.0, 13.0, 12.0, 9.0] end @@ -45,7 +45,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuinticHermiteSpline(ddu, du, u, t) + A = QuinticHermiteSpline(ddu, du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -6.5, -8.0, -10.0] @test A.p.c₂ ≈ [1.0, 19.5, 20.0, 19.0] @test A.p.c₃ ≈ [1.5, -37.5, -37.0, -26.5] diff --git a/test/runtests.jl b/test/runtests.jl index 0c722b2d..80080a75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,3 +10,4 @@ using SafeTestsets @safetestset "Online Tests" include("online_tests.jl") @safetestset "Regularization Smoothing" include("regularization.jl") @safetestset "Show methods" include("show.jl") +@safetestset "Zygote support" include("zygote_tests.jl") diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl new file mode 100644 index 00000000..6887ddd2 --- /dev/null +++ b/test/zygote_tests.jl @@ -0,0 +1,66 @@ +using DataInterpolations +using ForwardDiff +using Zygote + +function test_zygote(method, u, t; args = [], kwargs = [], name::String) + func = method(args..., u, t; kwargs..., extrapolate = true) + (; u, t) = func + trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) + trange_exclude = filter(x -> !in(x, t), trange) + @testset "$name, derivatives w.r.t. input" begin + for _t in trange_exclude + adiff = DataInterpolations.derivative(func, _t) + zdiff = only(Zygote.gradient(func, _t)) + zdiff == nothing && (zdiff = 0.0) + @test adiff ≈ zdiff + end + end + @testset "$name, derivatives w.r.t. u" begin + function f(u) + A = method(args..., u, t; kwargs..., extrapolate = true) + out = zero(eltype(u)) + for _t in trange + out += A(_t) + end + out + end + zgrad = only(Zygote.gradient(f, u)) + fgrad = ForwardDiff.gradient(f, u) + @test zgrad ≈ fgrad + end +end + +@testset "LinearInterpolation" begin + u = vcat(collect(1:5), 2 * collect(6:10)) + t = 1.0collect(1:10) + test_zygote( + LinearInterpolation, u, t; name = "Linear Interpolation") +end + +@testset "Quadratic Interpolation" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticInterpolation, u, t; name = "Quadratic Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation") +end + +@testset "Cubic Hermite Spline" begin + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote(CubicHermiteSpline, u, t, args = [du], name = "Cubic Hermite Spline") +end + +@testset "Quintic Hermite Spline" begin + ddu = [0.0, -0.00033, 0.0051, -0.0067, 0.0029, 0.0] + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote( + QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline") +end From 89757b62b97ecbbef722429824e4eb23c528c0cd Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 17:21:54 +0200 Subject: [PATCH 2/9] Refactor parameter caching, add zygote tests --- Project.toml | 6 +- docs/src/interface.md | 17 +- ext/DataInterpolationsOptimExt.jl | 5 +- ...ataInterpolationsRegularizationToolsExt.jl | 28 +- src/DataInterpolations.jl | 19 +- src/derivatives.jl | 23 +- src/integral_inverses.jl | 15 +- src/integrals.jl | 54 ++- src/interpolation_caches.jl | 335 +++++++++++------- src/interpolation_methods.jl | 30 +- src/interpolation_utils.jl | 84 +++-- src/online.jl | 96 ++--- test/interpolation_tests.jl | 1 - test/online_tests.jl | 8 +- test/parameter_tests.jl | 12 +- test/runtests.jl | 1 + test/zygote_tests.jl | 66 ++++ 17 files changed, 498 insertions(+), 302 deletions(-) create mode 100644 test/zygote_tests.jl diff --git a/Project.toml b/Project.toml index b06d2adb..963ee8cd 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -33,7 +32,6 @@ LinearAlgebra = "1.10" Optim = "1.6" PrettyTables = "2" QuadGK = "2.9.1" -ReadOnlyArrays = "0.2.0" RecipesBase = "1.3" Reexport = "1" RegularizationTools = "0.6" @@ -41,6 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" +Zygote = "0.6" julia = "1.10" [extras] @@ -55,6 +54,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"] +test = ["Aqua", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Zygote"] diff --git a/docs/src/interface.md b/docs/src/interface.md index ca5e9819..cfc9ed0b 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -35,22 +35,7 @@ A2(300.0) The values computed beyond the range of the time points provided during interpolation will not be reliable, as these methods only perform well within the range and the first/last piece polynomial fit is extrapolated on either side which might not reflect the true nature of the data. -The keyword `safetycopy = false` can be passed to make sure no copies of `u` and `t` are made when initializing the interpolation object. - -```@example interface -A3 = QuadraticInterpolation(u, t; safetycopy = false) - -# Check for same memory -u === A3.u.parent -``` - -Note that this does not prevent allocation in every interpolation constructor call, because parameter values are cached for all interpolation types except [`ConstantInterpolation`](@ref). - -Because of the caching of parameters which depend on `u` and `t`, this data should not be mutated. Therefore `u` and `t` are wrapped in a `ReadOnlyArray` from [ReadOnlyArrays.jl](https://github.com/JuliaArrays/ReadOnlyArrays.jl). - -```@repl interface -A3.t[2] = 3.14 -``` +The keyword `cache_parameters = true` can be passed to precalculate parameters at initialization, making evalations cheaper to compute. This is not compatible with modifying `u` and `t`. The default `cache_parameters = false` does however not prevent allocation in every interpolation constructor call. ## Derivatives diff --git a/ext/DataInterpolationsOptimExt.jl b/ext/DataInterpolationsOptimExt.jl index 5528503f..b3bce295 100644 --- a/ext/DataInterpolationsOptimExt.jl +++ b/ext/DataInterpolationsOptimExt.jl @@ -18,9 +18,8 @@ function Curvefit(u, box = false, lb = nothing, ub = nothing; - extrapolate = false, - safetycopy = false) - u, t = munge_data(u, t, safetycopy) + extrapolate = false) + u, t = munge_data(u, t) errfun(t, u, p) = sum(abs2.(u .- model(t, p))) if box == false mfit = optimize(p -> errfun(t, u, p), p0, alg) diff --git a/ext/DataInterpolationsRegularizationToolsExt.jl b/ext/DataInterpolationsRegularizationToolsExt.jl index 10ea3e4c..732ea1bb 100644 --- a/ext/DataInterpolationsRegularizationToolsExt.jl +++ b/ext/DataInterpolationsRegularizationToolsExt.jl @@ -69,8 +69,8 @@ A = RegularizationSmooth(u, t, t̂, wls, wr, d; λ = 1.0, alg = :gcv_svd) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) Wr½ = LA.diagm(sqrt.(wr)) @@ -86,8 +86,8 @@ A = RegularizationSmooth(u, t, d; λ = 1.0, alg = :gcv_svd, extrapolate = false) """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -115,8 +115,8 @@ A = RegularizationSmooth(u, t, t̂, d; λ = 1.0, alg = :gcv_svd, extrapolate = f """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = Array{Float64}(LA.I, N, N) @@ -143,8 +143,8 @@ A = RegularizationSmooth(u, t, t̂, wls, d; λ = 1.0, alg = :gcv_svd, extrapolat """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::AbstractVector, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) N, N̂ = length(t), length(t̂) M = _mapping_matrix(t̂, t) Wls½ = LA.diagm(sqrt.(wls)) @@ -172,8 +172,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, d::Int = 2; λ::Real = 1.0, - alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -202,8 +202,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::AbstractVector, wr::AbstractVector, d::Int = 2; - λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + λ::Real = 1.0, alg::Symbol = :gcv_svd, extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) @@ -232,8 +232,8 @@ A = RegularizationSmooth( """ function RegularizationSmooth(u::AbstractVector, t::AbstractVector, t̂::Nothing, wls::Symbol, d::Int = 2; λ::Real = 1.0, alg::Symbol = :gcv_svd, - extrapolate::Bool = false, safetycopy::Bool = true) - u, t = munge_data(u, t, safetycopy) + extrapolate::Bool = false) + u, t = munge_data(u, t) t̂ = t N = length(t) M = Array{Float64}(LA.I, N, N) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index c86a6579..19cb47c0 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -7,7 +7,6 @@ abstract type AbstractInterpolation{T} end using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff -using ReadOnlyArrays import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, bracketstrictlymontonic @@ -88,12 +87,6 @@ function Base.showerror(io::IO, e::IntegralNotInvertibleError) print(io, INTEGRAL_NOT_INVERTIBLE_ERROR) end -const MUST_COPY_ERROR = "A copy must be made of u, t to filter missing data" -struct MustCopyError <: Exception end -function Base.showerror(io::IO, e::MustCopyError) - print(io, MUST_COPY_ERROR) -end - export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, BSplineInterpolation, BSplineApprox, CubicHermiteSpline, @@ -126,12 +119,12 @@ struct RegularizationSmooth{uType, tType, T, T2, ITP <: AbstractInterpolation{T} Aitp, extrapolate) new{typeof(u), typeof(t), eltype(u), typeof(λ), typeof(Aitp)}( - readonly_wrap(u), - readonly_wrap(û), - readonly_wrap(t), - readonly_wrap(t̂), - readonly_wrap(oftype(u.parent, wls)), - readonly_wrap(oftype(u.parent, wr)), + u, + û, + t, + t̂, + wls, + wr, d, λ, alg, diff --git a/src/derivatives.jl b/src/derivatives.jl index 30c76fd0..01eb18bb 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -19,14 +19,16 @@ end function _derivative(A::LinearInterpolation, t::Number, iguess) idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -2, side = :first) - A.p.slope[idx], idx + slope = get_parameters(A, idx) + slope, idx end function _derivative(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - du₀ = A.p.l₀[i₀] * (2t - A.t[i₁] - A.t[i₂]) - du₁ = A.p.l₁[i₀] * (2t - A.t[i₀] - A.t[i₂]) - du₂ = A.p.l₂[i₀] * (2t - A.t[i₀] - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂]) + du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂]) + du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁]) return @views @. du₀ + du₁ + du₂, i₀ end @@ -129,7 +131,7 @@ end # QuadraticSpline Interpolation function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first) - σ = A.p.σ[idx - 1] + σ = get_parameters(A, idx - 1) A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx end @@ -139,8 +141,9 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t dI = (-A.z[idx] * Δt₂^2 + A.z[idx + 1] * Δt₁^2) / (2A.h[idx + 1]) - dC = A.p.c₁[idx] - dD = -A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + dC = c₁ + dD = -c₂ dI + dC + dD, idx end @@ -193,7 +196,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] - out += Δt₀ * (Δt₀ * A.p.c₂[idx] + 2(A.p.c₁[idx] + Δt₁ * A.p.c₂[idx])) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂)) out, idx end @@ -204,7 +208,8 @@ function _derivative( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.du[idx] + A.ddu[idx] * Δt₀ + c₁, c₂, c₃ = get_parameters(A, idx) out += Δt₀^2 * - (3A.p.c₁[idx] + (3Δt₁ + Δt₀) * A.p.c₂[idx] + (3Δt₁^2 + Δt₀ * 2Δt₁) * A.p.c₃[idx]) + (3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃) out, idx end diff --git a/src/integral_inverses.jl b/src/integral_inverses.jl index 4437726e..38c14b14 100644 --- a/src/integral_inverses.jl +++ b/src/integral_inverses.jl @@ -40,10 +40,9 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function LinearInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy) + u, t, A.extrapolate, Ref(1), A) end end @@ -51,9 +50,11 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) return all(A.u .> 0) end +get_I(A::AbstractInterpolation) = isnothing(A.I) ? cumulative_integral(A) : A.I + function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return LinearInterpolationIntInv(A.t, A.I, A) + return LinearInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( @@ -61,7 +62,8 @@ function _interpolate( idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] x = A.itp.u[idx] - u = A.u[idx] + 2Δt / (x + sqrt(x^2 + A.itp.p.slope[idx] * 2Δt)) + slope = get_parameters(A.itp, idx) + u = A.u[idx] + 2Δt / (x + sqrt(x^2 + slope * 2Δt)) u, idx end @@ -84,10 +86,9 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <: extrapolate::Bool idx_prev::Base.RefValue{Int} itp::itpType - safetycopy::Bool function ConstantInterpolationIntInv(u, t, A) new{typeof(u), typeof(t), typeof(A), eltype(u)}( - u, t, A.extrapolate, Ref(1), A, A.safetycopy + u, t, A.extrapolate, Ref(1), A ) end end @@ -98,7 +99,7 @@ end function invert_integral(A::ConstantInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) - return ConstantInterpolationIntInv(A.t, A.I, A) + return ConstantInterpolationIntInv(A.t, get_I(A), A) end function _interpolate( diff --git a/src/integrals.jl b/src/integrals.jl index 3040189f..03ea26c5 100644 --- a/src/integrals.jl +++ b/src/integrals.jl @@ -12,14 +12,24 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number) # the index less than t2 idx2 = get_idx(A.t, t2, 0; idx_shift = -1, side = :first) - total = A.I[idx2] - A.I[idx1] - return if t1 == t2 - zero(total) + if A.cache_parameters + total = A.I[idx2] - A.I[idx1] + return if t1 == t2 + zero(total) + else + total += _integral(A, idx1, A.t[idx1]) + total -= _integral(A, idx1, t1) + total += _integral(A, idx2, t2) + total -= _integral(A, idx2, A.t[idx2]) + total + end else - total += _integral(A, idx1, A.t[idx1]) - total -= _integral(A, idx1, t1) - total += _integral(A, idx2, t2) - total -= _integral(A, idx2, A.t[idx2]) + total = zero(eltype(A.u)) + for idx in idx1:idx2 + lt1 = idx == idx1 ? t1 : A.t[idx] + lt2 = idx == idx2 ? t2 : A.t[idx + 1] + total += _integral(A, idx, lt2) - _integral(A, idx, lt1) + end total end end @@ -28,7 +38,8 @@ function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt = t - A.t[idx] - Δt * (A.u[idx] + A.p.slope[idx] * Δt / 2) + slope = get_parameters(A, idx) + Δt * (A.u[idx] + slope * Δt / 2) end function _integral( @@ -52,24 +63,27 @@ function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}}, t₂ = A.t[idx + 2] t_sq = (t^2) / 3 - Iu₀ = A.p.l₀[idx] * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) - Iu₁ = A.p.l₁[idx] * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) - Iu₂ = A.p.l₂[idx] * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) + l₀, l₁, l₂ = get_parameters(A, idx) + Iu₀ = l₀ * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂) + Iu₁ = l₁ * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂) + Iu₂ = l₂ * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁) return Iu₀ + Iu₁ + Iu₂ end function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt + σ = get_parameters(A, idx) + return A.z[idx] * Δt^2 / 2 + σ * Δt^3 / 3 + Cᵢ * Δt end function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Δt₁sq = (t - A.t[idx])^2 / 2 Δt₂sq = (A.t[idx + 1] - t)^2 / 2 II = (-A.z[idx] * Δt₂sq^2 + A.z[idx + 1] * Δt₁sq^2) / (6A.h[idx + 1]) - IC = A.p.c₁[idx] * Δt₁sq - ID = -A.p.c₂[idx] * Δt₂sq + c₁, c₂ = get_parameters(A, idx) + IC = c₁ * Δt₁sq + ID = -c₂ * Δt₂sq II + IC + ID end @@ -91,8 +105,9 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + Δt₀ * A.du[idx] / 2) - p = A.p.c₁[idx] + Δt₁ * A.p.c₂[idx] - dp = A.p.c₂[idx] + c₁, c₂ = get_parameters(A, idx) + p = c₁ + Δt₁ * c₂ + dp = c₂ out += Δt₀^3 / 3 * (p - dp * Δt₀ / 4) out end @@ -103,9 +118,10 @@ function _integral( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = Δt₀ * (A.u[idx] + A.du[idx] * Δt₀ / 2 + A.ddu[idx] * Δt₀^2 / 6) - p = A.p.c₁[idx] + A.p.c₂[idx] * Δt₁ + A.p.c₃[idx] * Δt₁^2 - dp = A.p.c₂[idx] + 2A.p.c₃[idx] * Δt₁ - ddp = 2A.p.c₃[idx] + c₁, c₂, c₃ = get_parameters(A, idx) + p = c₁ + c₂ * Δt₁ + c₃ * Δt₁^2 + dp = c₂ + 2c₃ * Δt₁ + ddp = 2c₃ out += Δt₀^4 / 4 * (p - Δt₀ / 5 * dp + Δt₀^2 / 30 * ddp) out end diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index c7274471..286bf6bc 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -1,5 +1,5 @@ """ - LinearInterpolation(u, t; extrapolate = false) + LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using a linear polynomial. For any point, two data points one each side are chosen and connected with a line. Extrapolation extends the last linear polynomial on each side. @@ -12,7 +12,7 @@ Extrapolation extends the last linear polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -21,23 +21,33 @@ struct LinearInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolati p::LinearParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LinearInterpolation(u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.slope), eltype(u)}( - u, t, I, p, extrapolate, Ref(1), safetycopy) + u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function LinearInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = LinearParameterCache(u, t) - A = LinearInterpolation(u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - LinearInterpolation(u, t, I, p, extrapolate, safetycopy) +function LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = if cache_parameters + LinearParameterCache(u, t) + else + LinearParameterCache(nothing) + end + + A = LinearInterpolation(u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) + end + + A end """ - QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false) + QuadraticInterpolation(u, t, mode = :Forward; cache_parameters = false) It is the method of interpolating between the data points using quadratic polynomials. For any point, three data points nearby are taken to fit a quadratic polynomial. Extrapolation extends the last quadratic polynomial on each side. @@ -51,7 +61,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpolation{T} u::uType @@ -61,25 +71,35 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T} <: AbstractInterpol mode::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) mode ∈ (:Forward, :Backward) || error("mode should be :Forward or :Backward for QuadraticInterpolation") new{typeof(u), typeof(t), typeof(I), typeof(p.l₀), eltype(u)}( - u, t, I, p, mode, extrapolate, Ref(1), safetycopy) + u, t, I, p, mode, extrapolate, Ref(1), cache_parameters) end end -function QuadraticInterpolation(u, t, mode; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - p = QuadraticParameterCache(u, t) - A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) +function QuadraticInterpolation(u, t, mode; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + p = if cache_parameters + QuadraticParameterCache(u, t) + else + QuadraticParameterCache(nothing, nothing, nothing) + end + + A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) + end + + A end -function QuadraticInterpolation(u, t; extrapolate = false, safetycopy = true) - QuadraticInterpolation(u, t, :Forward; extrapolate, safetycopy) +function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = false) + QuadraticInterpolation(u, t, :Forward; extrapolate, cache_parameters) end """ @@ -96,7 +116,6 @@ It is the method of interpolation using Lagrange polynomials of (k-1)th order pa ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: AbstractInterpolation{T} @@ -107,8 +126,7 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: idxs::Vector{Int} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + function LagrangeInterpolation(u, t, n, extrapolate) bcache = zeros(eltype(u[1]), n + 1) idxs = zeros(Int, n + 1) fill!(bcache, NaN) @@ -118,23 +136,22 @@ struct LagrangeInterpolation{uType, tType, T, bcacheType} <: bcache, idxs, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function LagrangeInterpolation( - u, t, n = length(t) - 1; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, n = length(t) - 1; extrapolate = false) + u, t = munge_data(u, t) if n != length(t) - 1 error("Currently only n=length(t) - 1 is supported") end - LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + LagrangeInterpolation(u, t, n, extrapolate) end """ - AkimaInterpolation(u, t; extrapolate = false) + AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: https://en.wikipedia.org/wiki/Akima_spline. Extrapolation extends the last cubic polynomial on each side. @@ -147,7 +164,7 @@ Extrapolation extends the last cubic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: AbstractInterpolation{T} @@ -159,8 +176,8 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d::dType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + cache_parameters::Bool + function AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(b), typeof(c), typeof(d), eltype(u)}(u, t, @@ -170,13 +187,13 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <: d, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end -function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) +function AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) n = length(t) dt = diff(t) m = Array{eltype(u)}(undef, n + 3) @@ -197,13 +214,18 @@ function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 - A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, safetycopy) - I = cumulative_integral(A) - AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) + A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) + end + + A end """ - ConstantInterpolation(u, t; dir = :left, extrapolate = false) + ConstantInterpolation(u, t; dir = :left, extrapolate = false, cache_parameters = false) It is the method of interpolating using a constant polynomial. For any point, two adjacent data points are found on either side (left and right). The value at that point depends on `dir`. If it is `:left`, then the value at the left point is chosen and if it is `:right`, the value at the right point is chosen. @@ -218,7 +240,7 @@ Extrapolation extends the last constant polynomial at the end points on each sid - `dir`: indicates which value should be used for interpolation (`:left` or `:right`). - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} u::uType @@ -228,22 +250,28 @@ struct ConstantInterpolation{uType, tType, IType, T} <: AbstractInterpolation{T} dir::Symbol # indicates if value to the $dir should be used for the interpolation extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) + cache_parameters::Bool + function ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), eltype(u)}( - u, t, I, nothing, dir, extrapolate, Ref(1), safetycopy) + u, t, I, nothing, dir, extrapolate, Ref(1), cache_parameters) end end -function ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) - A = ConstantInterpolation(u, t, nothing, dir, extrapolate, safetycopy) - I = cumulative_integral(A) - ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) +function ConstantInterpolation( + u, t; dir = :left, extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + A = ConstantInterpolation(u, t, nothing, dir, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) + end + + A end """ - QuadraticSpline(u, t; extrapolate = false) + QuadraticSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise quadratic polynomials between each pair of data points. Its first derivative is also continuous. Extrapolation extends the last quadratic polynomial on each side. @@ -256,7 +284,7 @@ Extrapolation extends the last quadratic polynomial on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: AbstractInterpolation{T} @@ -269,8 +297,8 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + cache_parameters::Bool + function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA), typeof(d), typeof(z), eltype(u)}(u, t, @@ -281,15 +309,15 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function QuadraticSpline( u::uType, t; extrapolate = false, - safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + cache_parameters = false) where {uType <: AbstractVector{<:Number}} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -301,15 +329,27 @@ function QuadraticSpline( d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) z = tA \ d - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + + p = if cache_parameters + QuadraticSplineParameterCache(z, t) + else + QuadraticSplineParameterCache(nothing) + end + + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) + end + + A end function QuadraticSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) s = length(t) dl = ones(eltype(t), s - 1) d_tmp = ones(eltype(t), s) @@ -322,14 +362,23 @@ function QuadraticSpline( d = transpose(reshape(reduce(hcat, d_), :, s)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + p = if cache_parameters + QuadraticSplineParameterCache(z, t) + else + QuadraticSplineParameterCache(nothing) + end + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) + end + + A end """ - CubicSpline(u, t; extrapolate = false) + CubicSpline(u, t; extrapolate = false, cache_parameters = false) It is a spline interpolation using piecewise cubic polynomials between each pair of data points. Its first and second derivative is also continuous. Second derivative on both ends are zero, which are also called "natural" boundary conditions. Extrapolation extends the last cubic polynomial on each side. @@ -342,7 +391,7 @@ Second derivative on both ends are zero, which are also called "natural" boundar ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInterpolation{T} u::uType @@ -353,8 +402,8 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z::zType extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicSpline(u, t, I, p, h, z, extrapolate, safetycopy) + cache_parameters::Bool + function CubicSpline(u, t, I, p, h, z, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(p.c₁), typeof(h), typeof(z), eltype(u)}( u, t, @@ -364,15 +413,16 @@ struct CubicSpline{uType, tType, IType, pType, hType, zType, T} <: AbstractInter z, extrapolate, Ref(1), - safetycopy + cache_parameters ) end end function CubicSpline(u::uType, t; - extrapolate = false, safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) + extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector{<:Number}} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -389,15 +439,25 @@ function CubicSpline(u::uType, 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) z = tA \ d - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + p = if cache_parameters + CubicSplineParameterCache(u, h, z) + else + CubicSplineParameterCache(nothing, nothing) + end + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + end + + A end function CubicSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} - u, t = munge_data(u, t, safetycopy) + u::uType, t; extrapolate = false, cache_parameters = false) where {uType <: + AbstractVector} + u, t = munge_data(u, t) n = length(t) - 1 h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0) dl = vcat(h[2:n], zero(eltype(h))) @@ -411,10 +471,20 @@ function CubicSpline( d = transpose(reshape(reduce(hcat, d_), :, n + 1)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = CubicSplineParameterCache(u, h, z) - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) + p = if cache_parameters + CubicSplineParameterCache(u, h, z) + else + CubicSplineParameterCache(nothing, nothing) + end + + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + end + + A end """ @@ -434,7 +504,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -449,7 +518,6 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineInterpolation(u, t, d, @@ -459,8 +527,7 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy) + extrapolate) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, d, @@ -471,15 +538,14 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy + Ref(1) ) end end function BSplineInterpolation( - u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) n < d + 1 && error("BSplineInterpolation needs at least d + 1, i.e. $(d+1) points.") s = zero(eltype(u)) @@ -543,11 +609,11 @@ function BSplineInterpolation( c = vec(N \ u[:, :]) N = zeros(eltype(t), n) BSplineInterpolation( - u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + u, t, d, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false) + BSplineApprox(u, t, d, h, pVecType, knotVecType) It is a regression based B-spline. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h < length(t)` which is the number of control points to use, with smaller `h` indicating more smoothing. For more information, refer http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf. @@ -565,7 +631,6 @@ Extrapolation is a constant polynomial of the end points on each side. ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: AbstractInterpolation{T} @@ -581,7 +646,6 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: knotVecType::Symbol extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool function BSplineApprox(u, t, d, @@ -592,8 +656,7 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: N, pVecType, knotVecType, - extrapolate, - safetycopy + extrapolate ) new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), typeof(N), eltype(u)}(u, t, @@ -606,15 +669,14 @@ struct BSplineApprox{uType, tType, pType, kType, cType, NType, T} <: pVecType, knotVecType, extrapolate, - Ref(1), - safetycopy::Bool + Ref(1) ) end end function BSplineApprox( - u, t, d, h, pVecType, knotVecType; extrapolate = false, safetycopy = true) - u, t = munge_data(u, t, safetycopy) + u, t, d, h, pVecType, knotVecType; extrapolate = false) + u, t = munge_data(u, t) n = length(t) h < d + 1 && error("BSplineApprox needs at least d + 1, i.e. $(d+1) control points.") s = zero(eltype(u)) @@ -698,11 +760,12 @@ function BSplineApprox( P = M \ Q c[2:(end - 1)] .= vec(P) N = zeros(eltype(t), h) - BSplineApprox(u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate, safetycopy) + BSplineApprox( + u, t, d, h, p, k, c, N, pVecType, knotVecType, extrapolate) end """ - CubicHermiteSpline(du, u, t; extrapolate = false) + CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomial such that the value and the first derivative are equal to given values in the data points. @@ -715,7 +778,7 @@ It is a Cubic Hermite interpolation, which is a piece-wise third degree polynomi ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInterpolation{T} du::duType @@ -725,24 +788,33 @@ struct CubicHermiteSpline{uType, tType, IType, duType, pType, T} <: AbstractInte p::CubicHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(p.c₁), eltype(u)}( - du, u, t, I, p, extrapolate, Ref(1), safetycopy) + du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) +function CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du) "Length of `u` is not equal to length of `du`." - u, t = munge_data(u, t, safetycopy) - p = CubicHermiteParameterCache(du, u, t) - A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = if cache_parameters + CubicHermiteParameterCache(du, u, t) + else + CubicHermiteParameterCache(nothing, nothing) + end + A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) + end + + A end """ - QuinticHermiteSpline(ddu, du, u, t; extrapolate = false) + QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polynomial such that the value and the first and second derivative are equal to given values in the data points. @@ -756,7 +828,7 @@ It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polyno ## Keyword Arguments - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. """ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: AbstractInterpolation{T} @@ -768,19 +840,28 @@ struct QuinticHermiteSpline{uType, tType, IType, duType, dduType, pType, T} <: p::QuinticHermiteParameterCache{pType} extrapolate::Bool idx_prev::Base.RefValue{Int} - safetycopy::Bool - function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + cache_parameters::Bool + function QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) new{typeof(u), typeof(t), typeof(I), typeof(du), typeof(ddu), typeof(p.c₁), eltype(u)}( - ddu, du, u, t, I, p, extrapolate, Ref(1), safetycopy) + ddu, du, u, t, I, p, extrapolate, Ref(1), cache_parameters) end end -function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = true) +function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du)==length(ddu) "Length of `u` is not equal to length of `du` or `ddu`." - u, t = munge_data(u, t, safetycopy) - p = QuinticHermiteParameterCache(ddu, du, u, t) - A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) - QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) + u, t = munge_data(u, t) + p = if cache_parameters + QuinticHermiteParameterCache(ddu, du, u, t) + else + QuinticHermiteParameterCache(nothing, nothing, nothing) + end + A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, cache_parameters) + + if cache_parameters + I = cumulative_integral(A) + A = QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) + end + + A end diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 5d09ceff..c409d5fc 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -10,15 +10,15 @@ end function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess) if isnan(t) # For correct derivative with NaN - idx = firstindex(A.u) - 1 + idx = firstindex(A.u) t1 = t2 = one(eltype(A.t)) u1 = u2 = one(eltype(A.u)) - slope = t * one(eltype(A.p.slope)) + slope = t * get_parameters(A, idx) else idx = get_idx(A.t, t, iguess) t1, t2 = A.t[idx], A.t[idx + 1] u1, u2 = A.u[idx], A.u[idx + 1] - slope = A.p.slope[idx] + slope = get_parameters(A, idx) end Δt = t - t1 @@ -38,7 +38,8 @@ end function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Δt = t - A.t[idx] - return A.u[:, idx] + A.p.slope[idx] * Δt, idx + slope = get_parameters(A, idx) + return A.u[:, idx] + slope * Δt, idx end # Quadratic Interpolation @@ -50,9 +51,10 @@ end function _interpolate(A::QuadraticInterpolation, t::Number, iguess) i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess) - u₀ = A.p.l₀[i₀] * (t - A.t[i₁]) * (t - A.t[i₂]) - u₁ = A.p.l₁[i₀] * (t - A.t[i₀]) * (t - A.t[i₂]) - u₂ = A.p.l₂[i₀] * (t - A.t[i₀]) * (t - A.t[i₁]) + l₀, l₁, l₂ = get_parameters(A, i₀) + u₀ = l₀ * (t - A.t[i₁]) * (t - A.t[i₂]) + u₁ = l₁ * (t - A.t[i₀]) * (t - A.t[i₂]) + u₂ = l₂ * (t - A.t[i₀]) * (t - A.t[i₁]) return u₀ + u₁ + u₂, i₀ end @@ -149,7 +151,8 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx + σ = get_parameters(A, idx) + return A.z[idx] * Δt + σ * Δt^2 + Cᵢ, idx end # CubicSpline Interpolation @@ -158,8 +161,9 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess) Δt₁ = t - A.t[idx] Δt₂ = A.t[idx + 1] - t I = (A.z[idx] * Δt₂^3 + A.z[idx + 1] * Δt₁^3) / (6A.h[idx + 1]) - C = A.p.c₁[idx] * Δt₁ - D = A.p.c₂[idx] * Δt₂ + c₁, c₂ = get_parameters(A, idx) + C = c₁ * Δt₁ + D = c₂ * Δt₂ I + C + D, idx end @@ -205,7 +209,8 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * A.du[idx] - out += Δt₀^2 * (A.p.c₁[idx] + Δt₁ * A.p.c₂[idx]) + c₁, c₂ = get_parameters(A, idx) + out += Δt₀^2 * (c₁ + Δt₁ * c₂) out, idx end @@ -216,6 +221,7 @@ function _interpolate( Δt₀ = t - A.t[idx] Δt₁ = t - A.t[idx + 1] out = A.u[idx] + Δt₀ * (A.du[idx] + A.ddu[idx] * Δt₀ / 2) - out += Δt₀^3 * (A.p.c₁[idx] + Δt₁ * (A.p.c₂[idx] + A.p.c₃[idx] * Δt₁)) + c₁, c₂, c₃ = get_parameters(A, idx) + out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁)) out, idx end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 466248b1..17b08328 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -60,15 +60,11 @@ function spline_coefficients!(N, d, k, u::AbstractVector) end # helper function for data manipulation -function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}, safetycopy::Bool) - if safetycopy - u = copy(u) - t = copy(t) - end - return readonly_wrap(u), readonly_wrap(t) +function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real}) + return u, t end -function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) +function munge_data(u::AbstractVector, t::AbstractVector) Tu = Base.nonmissingtype(eltype(u)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == length(u) @@ -77,17 +73,13 @@ function munge_data(u::AbstractVector, t::AbstractVector, safetycopy::Bool) if !ismissing(u[i]) && !ismissing(t[i]) ) - if safetycopy - u = Tu.([u[i] for i in non_missing_indices]) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + u = Tu.([u[i] for i in non_missing_indices]) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(u), readonly_wrap(t) + return u, t end -function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) +function munge_data(U::StridedMatrix, t::AbstractVector) TU = Base.nonmissingtype(eltype(U)) Tt = Base.nonmissingtype(eltype(t)) @assert length(t) == size(U, 2) @@ -96,20 +88,12 @@ function munge_data(U::StridedMatrix, t::AbstractVector, safetycopy::Bool) if !any(ismissing, U[:, i]) && !ismissing(t[i]) ) - if safetycopy - U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) - t = Tt.([t[i] for i in non_missing_indices]) - else - !isempty(non_missing_indices) && throw(MustCopyError()) - end + U = hcat([TU.(U[:, i]) for i in non_missing_indices]...) + t = Tt.([t[i] for i in non_missing_indices]) - return readonly_wrap(U), readonly_wrap(t) + return U, t end -# Don't nest ReadOnlyArrays -readonly_wrap(a::AbstractArray) = ReadOnlyArray(a) -readonly_wrap(a::ReadOnlyArray) = a - function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = :last) ub = length(tvec) + ub_shift return if side == :last @@ -130,3 +114,51 @@ function cumulative_integral(A) pushfirst!(integral_values, zero(first(integral_values))) return cumsum(integral_values) end + +function get_parameters(A::LinearInterpolation, idx) + if A.cache_parameters + A.p.slope[idx] + else + linear_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticInterpolation, idx) + if A.cache_parameters + A.p.l₀[idx], A.p.l₁[idx], A.p.l₂[idx] + else + quadratic_interpolation_parameters(A.u, A.t, idx) + end +end + +function get_parameters(A::QuadraticSpline, idx) + if A.cache_parameters + A.p.σ[idx] + else + quadratic_spline_parameters(A.z, A.t, idx) + end +end + +function get_parameters(A::CubicSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_spline_parameters(A.u, A.h, A.z, idx) + end +end + +function get_parameters(A::CubicHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx] + else + cubic_hermite_spline_parameters(A.du, A.u, A.t, idx) + end +end + +function get_parameters(A::QuinticHermiteSpline, idx) + if A.cache_parameters + A.p.c₁[idx], A.p.c₂[idx], A.p.c₃[idx] + else + quintic_hermite_spline_parameters(A.ddu, A.du, A.u, A.t, idx) + end +end diff --git a/src/online.jl b/src/online.jl index 0fab5d44..5193e6b2 100644 --- a/src/online.jl +++ b/src/online.jl @@ -9,69 +9,81 @@ function add_integral_values!(A) end function push!(A::LinearInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) - push!(A.p.slope, slope) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) + push!(A.p.slope, slope) + add_integral_values!(A) + end A end function push!(A::QuadraticInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) - push!(A.p.l₀, l₀) - push!(A.p.l₁, l₁) - push!(A.p.l₂, l₂) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) + push!(A.p.l₀, l₀) + push!(A.p.l₁, l₁) + push!(A.p.l₂, l₂) + add_integral_values!(A) + end A end function push!(A::ConstantInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} - push!(A.u.parent, u) - push!(A.t.parent, t) - add_integral_values!(A) + push!(A.u, u) + push!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::LinearInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::LinearInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - slope = linear_interpolation_parameters.( - Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) - append!(A.p.slope, slope) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + slope = linear_interpolation_parameters.( + Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) + append!(A.p.slope, slope) + add_integral_values!(A) + end A end function append!( - A::ConstantInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - add_integral_values!(A) + A::ConstantInterpolation{U, T}, u::U, t::T) where { + U, T} + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + add_integral_values!(A) + end A end function append!( - A::QuadraticInterpolation{ReadOnlyVector{eltypeU, U}, ReadOnlyVector{eltypeT, T}}, u::U, t::T) where { - eltypeU, U, eltypeT, T} + A::QuadraticInterpolation{U, T}, u::U, t::T) where { + U, T} length_old = length(A.t) - u, t = munge_data(u, t, true) - append!(A.u.parent, u) - append!(A.t.parent, t) - parameters = quadratic_interpolation_parameters.( - Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) - append!(A.p.l₀, l₀) - append!(A.p.l₁, l₁) - append!(A.p.l₂, l₂) - add_integral_values!(A) + u, t = munge_data(u, t) + append!(A.u, u) + append!(A.t, t) + if A.cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) + append!(A.p.l₀, l₀) + append!(A.p.l₁, l₁) + append!(A.p.l₂, l₂) + add_integral_values!(A) + end A end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9549038c..9562c7b4 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -9,7 +9,6 @@ function test_interpolation_type(T) @test hasfield(T, :t) @test hasfield(T, :extrapolate) @test hasfield(T, :idx_prev) - @test hasfield(T, :safetycopy) @test !isempty(methods(DataInterpolations._interpolate, (T, Any, Number))) @test !isempty(methods(DataInterpolations._integral, (T, Any, Number))) @test !isempty(methods(DataInterpolations._derivative, (T, Any, Number))) diff --git a/test/online_tests.jl b/test/online_tests.jl index f9c3e1dd..3ae6438e 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -9,9 +9,9 @@ u2 = [1.0, 2.0, 1.0] ts = 1.0:0.5:6.0 for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) append!(func1, u2, t2) - func2 = method(vcat(u1, u2), vcat(t1, t2)) + func2 = method(vcat(u1, u2), vcat(t1, t2); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) @@ -20,9 +20,9 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio @test func1(ts) == func2(ts) @test func1.I == func2.I - func1 = method(u1, t1) + func1 = method(copy(u1), copy(t1); cache_parameters = true) push!(func1, 1.0, 4.0) - func2 = method(vcat(u1, 1.0), vcat(t1, 4.0)) + func2 = method(vcat(u1, 1.0), vcat(t1, 4.0); cache_parameters = true) @test func1.u == func2.u @test func1.t == func2.t for name in propertynames(func1.p) diff --git a/test/parameter_tests.jl b/test/parameter_tests.jl index bcd26cf7..2e84b98d 100644 --- a/test/parameter_tests.jl +++ b/test/parameter_tests.jl @@ -3,14 +3,14 @@ using DataInterpolations @testset "Linear Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = LinearInterpolation(u, t) + A = LinearInterpolation(u, t; cache_parameters = true) @test A.p.slope ≈ [4.0, -2.0, 1.0, 0.0] end @testset "Quadratic Interpolation" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticInterpolation(u, t) + A = QuadraticInterpolation(u, t; cache_parameters = true) @test A.p.l₀ ≈ [0.5, 2.5, 1.5] @test A.p.l₁ ≈ [-5.0, -3.0, -4.0] @test A.p.l₂ ≈ [1.5, 2.0, 2.0] @@ -19,14 +19,14 @@ end @testset "Quadratic Spline" begin u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuadraticSpline(u, t) + A = QuadraticSpline(u, t; cache_parameters = true) @test A.p.σ ≈ [4.0, -10.0, 13.0, -14.0] end @testset "Cubic Spline" begin u = [1, 5, 3, 4, 4] t = collect(1:5) - A = CubicSpline(u, t) + A = CubicSpline(u, t; cache_parameters = true) @test A.p.c₁ ≈ [6.839285714285714, 1.642857142857143, 4.589285714285714, 4.0] @test A.p.c₂ ≈ [1.0, 6.839285714285714, 1.642857142857143, 4.589285714285714] end @@ -35,7 +35,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = CubicHermiteSpline(du, u, t) + A = CubicHermiteSpline(du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -5.0, -5.0, -8.0] @test A.p.c₂ ≈ [0.0, 13.0, 12.0, 9.0] end @@ -45,7 +45,7 @@ end du = [5.0, 3.0, 6.0, 8.0, 1.0] u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) - A = QuinticHermiteSpline(ddu, du, u, t) + A = QuinticHermiteSpline(ddu, du, u, t; cache_parameters = true) @test A.p.c₁ ≈ [-1.0, -6.5, -8.0, -10.0] @test A.p.c₂ ≈ [1.0, 19.5, 20.0, 19.0] @test A.p.c₃ ≈ [1.5, -37.5, -37.0, -26.5] diff --git a/test/runtests.jl b/test/runtests.jl index 0c722b2d..80080a75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,3 +10,4 @@ using SafeTestsets @safetestset "Online Tests" include("online_tests.jl") @safetestset "Regularization Smoothing" include("regularization.jl") @safetestset "Show methods" include("show.jl") +@safetestset "Zygote support" include("zygote_tests.jl") diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl new file mode 100644 index 00000000..6887ddd2 --- /dev/null +++ b/test/zygote_tests.jl @@ -0,0 +1,66 @@ +using DataInterpolations +using ForwardDiff +using Zygote + +function test_zygote(method, u, t; args = [], kwargs = [], name::String) + func = method(args..., u, t; kwargs..., extrapolate = true) + (; u, t) = func + trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) + trange_exclude = filter(x -> !in(x, t), trange) + @testset "$name, derivatives w.r.t. input" begin + for _t in trange_exclude + adiff = DataInterpolations.derivative(func, _t) + zdiff = only(Zygote.gradient(func, _t)) + zdiff == nothing && (zdiff = 0.0) + @test adiff ≈ zdiff + end + end + @testset "$name, derivatives w.r.t. u" begin + function f(u) + A = method(args..., u, t; kwargs..., extrapolate = true) + out = zero(eltype(u)) + for _t in trange + out += A(_t) + end + out + end + zgrad = only(Zygote.gradient(f, u)) + fgrad = ForwardDiff.gradient(f, u) + @test zgrad ≈ fgrad + end +end + +@testset "LinearInterpolation" begin + u = vcat(collect(1:5), 2 * collect(6:10)) + t = 1.0collect(1:10) + test_zygote( + LinearInterpolation, u, t; name = "Linear Interpolation") +end + +@testset "Quadratic Interpolation" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticInterpolation, u, t; name = "Quadratic Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t; name = "Constant Interpolation") +end + +@testset "Cubic Hermite Spline" begin + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote(CubicHermiteSpline, u, t, args = [du], name = "Cubic Hermite Spline") +end + +@testset "Quintic Hermite Spline" begin + ddu = [0.0, -0.00033, 0.0051, -0.0067, 0.0029, 0.0] + du = [-0.047, -0.058, 0.054, 0.012, -0.068, 0.0] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 252.3] + test_zygote( + QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline") +end From f69ad15d61d8c46d7a786f01e5c84930f186dc71 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 20:55:48 +0200 Subject: [PATCH 3/9] Merge remote-tracking branch 'upstream/master' into cache_parameters_opt_in --- .github/workflows/CompatHelper.yml | 2 +- .github/workflows/Downgrade.yml | 2 +- Project.toml | 2 +- README.md | 7 + docs/Project.toml | 10 +- docs/make.jl | 3 +- docs/src/index.md | 13 +- docs/src/manual.md | 1 + docs/src/methods.md | 51 ++--- docs/src/symbolics.md | 65 ++++++ joss/paper.bib | 14 ++ joss/paper.md | 21 +- src/DataInterpolations.jl | 14 +- src/derivatives.jl | 2 +- src/integral_inverses.jl | 2 +- src/interpolation_caches.jl | 178 +++++---------- src/interpolation_utils.jl | 50 ++++- src/parameter_caches.jl | 105 ++++++--- src/plot_rec.jl | 349 +++++++++++++++++++++-------- test/derivative_tests.jl | 6 + test/interface.jl | 60 +++-- test/interpolation_tests.jl | 20 ++ test/online_tests.jl | 10 +- 23 files changed, 678 insertions(+), 309 deletions(-) create mode 100644 docs/src/symbolics.md diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 36e59135..35cc34ba 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -13,7 +13,7 @@ jobs: CompatHelper: runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@3645a07f58c7f83b9f82ac8e0bb95583e69149e6 + - uses: julia-actions/setup-julia@780022b48dfc0c2c6b94cfee6a9284850107d037 with: version: 1.3 - name: Pkg.add("CompatHelper") diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index c0d0123e..4546ebd0 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -28,7 +28,7 @@ jobs: - windows-latest steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2.2.0 + - uses: julia-actions/setup-julia@v2.3.0 with: version: ${{ matrix.version }} - uses: julia-actions/julia-downgrade-compat@v1 diff --git a/Project.toml b/Project.toml index 963ee8cd..ccef602b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DataInterpolations" uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -version = "5.2.0" +version = "5.3.1" [deps] FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" diff --git a/README.md b/README.md index 8a326579..5348685f 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ corresponding to `(u,t)` pairs. + `knotVec` - Symbol to Knot Vector, `knotVec = :Uniform` for uniform knot vector, `knotVec = :Average` for average spaced knot vector. - `BSplineApprox(u,t,d,h,pVec,knotVec)` - A regression B-spline which smooths the fitting curve. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h "methods.md", - "Interface" => "interface.md", "Manual" => "manual.md", "Inverting Integrals" => "inverting_integrals.md"]) + "Interface" => "interface.md", "Using with Symbolics/ModelingToolkit" => "symbolics.md", + "Manual" => "manual.md", "Inverting Integrals" => "inverting_integrals.md"]) deploydocs(repo = "github.com/SciML/DataInterpolations.jl"; push_preview = true) diff --git a/docs/src/index.md b/docs/src/index.md index f5b1512b..f2075f8e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,10 +1,6 @@ # DataInterpolations.jl -DataInterpolations.jl is a library for performing interpolations of one-dimensional data. By -"data interpolations" we mean techniques for interpolating possibly noisy data, and thus -some methods are mixtures of regressions with interpolations (i.e. do not hit the data -points exactly, smoothing out the lines). This library can be used to fill in intermediate -data points in applications like timeseries data. +DataInterpolations.jl is a library for performing interpolations of one-dimensional data. Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. By "data interpolations" we mean techniques for interpolating possibly noisy data, and thus some methods are mixtures of regressions with interpolations (i.e. do not hit the data points exactly, smoothing out the lines). ## Installation @@ -35,6 +31,7 @@ corresponding to `(u,t)` pairs. + `knotVec` - Symbol to Knot Vector, `knotVec = :Uniform` for uniform knot vector, `knotVec = :Average` for average spaced knot vector. - `BSplineApprox(u,t,d,h,pVec,knotVec)` - A regression B-spline which smooths the fitting curve. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h 0.5)) == cos(0.5) * A(0.5) # true +``` + +### Symbolic Derivatives + +```@example symbolics +D = Differential(τ) + +ex1 = A(τ) + +# Derivative of interpolation +ex2 = expand_derivatives(D(ex1)) + +@test substitute(ex2, Dict(τ => 0.5)) == DataInterpolations.derivative(A, 0.5) # true + +# Higher Order Derivatives +ex3 = expand_derivatives(D(D(A(τ)))) + +@test substitute(ex3, Dict(τ => 0.5)) == DataInterpolations.derivative(A, 0.5, 2) # true +``` + +## Using with ModelingToolkit.jl + +Most common use case with [ModelingToolkit.jl](https://docs.sciml.ai/ModelingToolkit/stable/) is to plug in interpolation objects as input functions. This can be done using `TimeVaryingFunction` component of [ModelingToolkitStandardLibrary.jl](https://docs.sciml.ai/ModelingToolkitStandardLibrary/stable/). + +```@example mtk +using DataInterpolations +using ModelingToolkitStandardLibrary.Blocks +using ModelingToolkit +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq + +us = [0.0, 1.5, 0.0] +times = [0.0, 0.5, 1.0] +A = LinearInterpolation(us, times) + +@named src = TimeVaryingFunction(A) +vars = @variables x(t) out(t) +eqs = [out ~ src.output.u, D(x) ~ 1 + out] +@named sys = ODESystem(eqs, t, vars, []; systems = [src]) + +sys = structural_simplify(sys) +prob = ODEProblem(sys, [x => 0.0], (times[1], times[end])) +sol = solve(prob) +``` diff --git a/joss/paper.bib b/joss/paper.bib index d754d9cc..101b0181 100644 --- a/joss/paper.bib +++ b/joss/paper.bib @@ -134,3 +134,17 @@ @book{lagrange1898lectures year={1898}, publisher={Open court publishing Company} } + +@article{doi:10.1137/0905021, + author = {Fritsch, F. N. and Butland, J.}, + title = {A Method for Constructing Local Monotone Piecewise Cubic Interpolants}, + journal = {SIAM Journal on Scientific and Statistical Computing}, + volume = {5}, + number = {2}, + pages = {300-304}, + year = {1984}, + doi = {10.1137/0905021}, + URL = {https://doi.org/10.1137/0905021}, + eprint = {https://doi.org/10.1137/0905021}, + abstract = { A method is described for producing monotone piecewise cubic interpolants to monotone data which is completely local and which is extremely simple to implement. } +} diff --git a/joss/paper.md b/joss/paper.md index f2d2a4be..66d15aa0 100644 --- a/joss/paper.md +++ b/joss/paper.md @@ -31,15 +31,30 @@ bibliography: paper.bib # Summary -Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include Constant Interpolation, Linear Interpolation, Quadratic Interpolation, Lagrange Interpolation [@lagrange], Quadratic Splines, Cubic Splines [@Schoenberg1988], Akima Splines [@10.1145/321607.321609], Cubic Hermite Splines, Quintic Hermite Splines, B-Splines [@Curry1988] [@DEBOOR197250] and Regression based B-Splines. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. +Interpolations are used to estimate values between known data points using an approximate continuous function.DataInterpolations.jl is a Julia [@Bezanson2017] package containing 1D implementations of some of the most commonly used interpolation functions. These include: + + - Constant Interpolation + - Linear Interpolation + - Quadratic Interpolation + - Lagrange Interpolation [@lagrange] + - Quadratic Splines + - Cubic Splines [@Schoenberg1988] + - Akima Splines [@10.1145/321607.321609] + - Cubic Hermite Splines + - Piecewise Cubic Hermite Interpolating Polynomial (PCHIP) [@doi:10.1137/0905021] + - Quintic Hermite Splines + - B-Splines [@Curry1988] [@DEBOOR197250] + - Regression based B-Splines + +and a continually growing list. Along with these, the package also has methods to fit parameterized curves with the data points and Tikhonov regularization [@Tikhonov1943OnTS] [@amt-14-7909-2021] for obtaining smooth curves. The package also provides functionality to compute integrals and derivatives upto second order for those interpolations methods. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. # Statement of need -Interpolations are a very important component of many modeling workflows. In many models, inputs which are sampled or measured need to be represented as a continuous function or a smooth curve for simulation. In many scientific machine learning workflows, we need interpolations of data to learn continuous models. There already have been a few interpolation packages in Julia like Interpolations.jl but it has a limitation of assuming uniformly spaced data which is not usually the case with data collected from real world. DataInterpolations.jl provides fast interpolation methods for arbitrary spaced 1D data with a consistent and simple interface. It is also automatic differentiation friendly. It can also be used symbolically with Symbolics.jl [@gowda2021high] and plugged into models defined using ModelingToolkit.jl [@ma2021modelingtoolkit]. +Interpolations are a very important component of many modeling workflows. Often, sampled or measured inputs need to be transformed into continuous functions or smooth curves for simulation purposes. In many scientific machine learning workflows, interpolating data is essential to learn continuous models. DataInterpolations.jl can be used for facilitating these types of workflows. Several interpolation packages already exist in Julia, such as [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/), which primarily specializes in B-Splines and uniformly spaced data with some support for irregularly spaced data. In contrast, DataInterpolations.jl does not assume any specific structure in the data, offering greater flexibility for diverse datasets. [Interpolations.jl](https://juliamath.github.io/Interpolations.jl/stable/) also doesn't offer methods like Quadratic Interpolation, Lagrange Interpolation, Hermite Splines etc. [BasicInterpolators.jl](https://github.com/markmbaum/BasicInterpolators.jl) is more similar to DataInterpolations.jl, although it doesn't offer methods like B-Splines. Rest of the interpolation packages focus on particular methods like [BSplineKit.jl](https://github.com/jipolanco/BSplineKit.jl) for B-Splines, [FastChebInterp.jl](https://github.com/JuliaMath/FastChebInterp.jl) for Chebyshev interpolation, [PCHIPInterpolation](https://github.com/gerlero/PCHIPInterpolation.jl) for PCHIP interpolation etc. Additionally, DataInterpolations.jl includes many novel techniques for accelerating the interpolation searches with specialized caching, quasi-linear guessing, and more to improve the performance algorithmically, beyond the simple computational optimizations. In summary, DataInterpolations.jl is more generic from other packages and offers many fast interpolation methods for arbitrarily spaced 1D data, all within a consistent and simple interface. # Example -The following tutorials in the documentation [1](https://docs.sciml.ai/DataInterpolations/stable/methods/) provides how to define each of the interpolation methods and compute the value at any point. [2](https://docs.sciml.ai/DataInterpolations/stable/interface/) provides explanation for using the interface and interpolated objects for evaluating at any point, computing the derivative at any point and computing the integral between any two points. +The following tutorials in the documentation [1](https://docs.sciml.ai/DataInterpolations/stable/methods/) provides how to define each of the interpolation methods and compute the value at any point. [2](https://docs.sciml.ai/DataInterpolations/stable/interface/) provides explanation for using the interface and interpolated objects for evaluating at any point, computing the derivative at any point and computing the integral between any two points. [3](https://docs.sciml.ai/DataInterpolations/stable/symbolics/) provides how to use interpolation objects with Symbolics.jl and ModelingToolkit.jl. A simple demonstration here: diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index 19cb47c0..7f44c878 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -22,7 +22,11 @@ include("online.jl") include("show.jl") (interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t) -(interp::AbstractInterpolation)(t::Number, i::Integer) = _interpolate(interp, t, i) +function (interp::AbstractInterpolation)(t::Number, i::Integer) + interp.idx_prev[] = i + _interpolate(interp, t) +end + function (interp::AbstractInterpolation)(t::AbstractVector) u = get_u(interp.u, t) interp(u, t) @@ -43,16 +47,14 @@ function get_u(u::AbstractMatrix, t) end function (interp::AbstractInterpolation)(u::AbstractMatrix, t::AbstractVector) - iguess = firstindex(interp.t) @inbounds for i in eachindex(t) - u[:, i], iguess = interp(t[i], iguess) + u[:, i] = interp(t[i]) end u end function (interp::AbstractInterpolation)(u::AbstractVector, t::AbstractVector) - iguess = firstindex(interp.t) @inbounds for i in eachindex(u, t) - u[i], iguess = interp(t[i], iguess) + u[i] = interp(t[i]) end u end @@ -89,7 +91,7 @@ end export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, - BSplineInterpolation, BSplineApprox, CubicHermiteSpline, + BSplineInterpolation, BSplineApprox, CubicHermiteSpline, PCHIPInterpolation, QuinticHermiteSpline, LinearInterpolationIntInv, ConstantInterpolationIntInv # added for RegularizationSmooth, JJS 11/27/21 diff --git a/src/derivatives.jl b/src/derivatives.jl index 01eb18bb..75872095 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -18,7 +18,7 @@ function derivative(A, t, order = 1) end function _derivative(A::LinearInterpolation, t::Number, iguess) - idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -2, side = :first) + idx = get_idx(A.t, t, iguess; idx_shift = -1, ub_shift = -1, side = :first) slope = get_parameters(A, idx) slope, idx end diff --git a/src/integral_inverses.jl b/src/integral_inverses.jl index 38c14b14..31d0853c 100644 --- a/src/integral_inverses.jl +++ b/src/integral_inverses.jl @@ -50,7 +50,7 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) return all(A.u .> 0) end -get_I(A::AbstractInterpolation) = isnothing(A.I) ? cumulative_integral(A) : A.I +get_I(A::AbstractInterpolation) = isempty(A.I) ? cumulative_integral(A, true) : A.I function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}}) !invertible_integral(A) && throw(IntegralNotInvertibleError()) diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 286bf6bc..7f6991b8 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -30,24 +30,14 @@ end function LinearInterpolation(u, t; extrapolate = false, cache_parameters = false) u, t = munge_data(u, t) - p = if cache_parameters - LinearParameterCache(u, t) - else - LinearParameterCache(nothing) - end - + p = LinearParameterCache(u, t, cache_parameters) A = LinearInterpolation(u, t, nothing, p, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) end """ - QuadraticInterpolation(u, t, mode = :Forward; cache_parameters = false) + QuadraticInterpolation(u, t, mode = :Forward; extrapolate = false, cache_parameters = false) It is the method of interpolating between the data points using quadratic polynomials. For any point, three data points nearby are taken to fit a quadratic polynomial. Extrapolation extends the last quadratic polynomial on each side. @@ -82,20 +72,10 @@ end function QuadraticInterpolation(u, t, mode; extrapolate = false, cache_parameters = false) u, t = munge_data(u, t) - p = if cache_parameters - QuadraticParameterCache(u, t) - else - QuadraticParameterCache(nothing, nothing, nothing) - end - + p = QuadraticParameterCache(u, t, cache_parameters) A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) end function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = false) @@ -103,7 +83,7 @@ function QuadraticInterpolation(u, t; extrapolate = false, cache_parameters = fa end """ - LagrangeInterpolation(u, t, n = length(t) - 1; extrapolate = false) + LagrangeInterpolation(u, t, n = length(t) - 1; extrapolate = false, safetycopy = true) It is the method of interpolation using Lagrange polynomials of (k-1)th order passing through all the data points where k is the number of data points. @@ -153,7 +133,7 @@ end """ AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) -It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: https://en.wikipedia.org/wiki/Akima_spline. +It is a spline interpolation built from cubic polynomials. It forms a continuously differentiable function. For more details, refer: [https://en.wikipedia.org/wiki/Akima_spline](https://en.wikipedia.org/wiki/Akima_spline). Extrapolation extends the last cubic polynomial on each side. ## Arguments @@ -215,13 +195,8 @@ function AkimaInterpolation(u, t; extrapolate = false, cache_parameters = false) d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters) end """ @@ -261,13 +236,8 @@ function ConstantInterpolation( u, t; dir = :left, extrapolate = false, cache_parameters = false) u, t = munge_data(u, t) A = ConstantInterpolation(u, t, nothing, dir, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + ConstantInterpolation(u, t, I, dir, extrapolate, cache_parameters) end """ @@ -330,20 +300,10 @@ function QuadraticSpline( d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) z = tA \ d - p = if cache_parameters - QuadraticSplineParameterCache(z, t) - else - QuadraticSplineParameterCache(nothing) - end - + p = QuadraticSplineParameterCache(z, t, cache_parameters) A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) end function QuadraticSpline( @@ -362,19 +322,11 @@ function QuadraticSpline( d = transpose(reshape(reduce(hcat, d_), :, s)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = if cache_parameters - QuadraticSplineParameterCache(z, t) - else - QuadraticSplineParameterCache(nothing) - end - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) - end - A + p = QuadraticSplineParameterCache(z, t, cache_parameters) + A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters) end """ @@ -439,19 +391,11 @@ function CubicSpline(u::uType, 6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i], 1:(n + 1)) z = tA \ d - p = if cache_parameters - CubicSplineParameterCache(u, h, z) - else - CubicSplineParameterCache(nothing, nothing) - end - A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) - if cache_parameters - I = cumulative_integral(A) - A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) - end - - A + p = CubicSplineParameterCache(u, h, z, cache_parameters) + A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) end function CubicSpline( @@ -471,26 +415,17 @@ function CubicSpline( d = transpose(reshape(reduce(hcat, d_), :, n + 1)) z_ = reshape(transpose(tA \ d), size(u[1])..., :) z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = if cache_parameters - CubicSplineParameterCache(u, h, z) - else - CubicSplineParameterCache(nothing, nothing) - end + p = CubicSplineParameterCache(u, h, z, cache_parameters) A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters) end """ - BSplineInterpolation(u, t, d, pVecType, knotVecType; extrapolate = false) + BSplineInterpolation(u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true) -It is a curve defined by the linear combination of `n` basis functions of degree `d` where `n` is the number of data points. For more information, refer https://pages.mtu.edu/~shene/COURSES/cs3621/NOTES/spline/B-spline/bspline-curve.html. +It is a curve defined by the linear combination of `n` basis functions of degree `d` where `n` is the number of data points. For more information, refer [https://pages.mtu.edu/~shene/COURSES/cs3621/NOTES/spline/B-spline/bspline-curve.html](https://pages.mtu.edu/%7Eshene/COURSES/cs3621/NOTES/spline/B-spline/bspline-curve.html). Extrapolation is a constant polynomial of the end points on each side. ## Arguments @@ -613,10 +548,10 @@ function BSplineInterpolation( end """ - BSplineApprox(u, t, d, h, pVecType, knotVecType) + BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false) It is a regression based B-spline. The argument choices are the same as the `BSplineInterpolation`, with the additional parameter `h < length(t)` which is the number of control points to use, with smaller `h` indicating more smoothing. -For more information, refer http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf. +For more information, refer [http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf](http://www.cad.zju.edu.cn/home/zhx/GM/009/00-bsia.pdf). Extrapolation is a constant polynomial of the end points on each side. ## Arguments @@ -798,23 +733,37 @@ end function CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du) "Length of `u` is not equal to length of `du`." u, t = munge_data(u, t) - p = if cache_parameters - CubicHermiteParameterCache(du, u, t) - else - CubicHermiteParameterCache(nothing, nothing) - end + p = CubicHermiteParameterCache(du, u, t, cache_parameters) A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, cache_parameters) + I = cumulative_integral(A, cache_parameters) + CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) +end - if cache_parameters - I = cumulative_integral(A) - A = CubicHermiteSpline(du, u, t, I, p, extrapolate, cache_parameters) - end +""" + PCHIPInterpolation(u, t; extrapolate = false, safetycopy = true) + +It is a PCHIP Interpolation, which is a type of [`CubicHermiteSpline`](@ref) where the derivative values `du` are derived from the input data +in such a way that the interpolation never overshoots the data. See [here](https://www.mathworks.com/content/dam/mathworks/mathworks-dot-com/moler/interp.pdf), +section 3.4 for more details. + +## Arguments - A + - `u`: data points. + - `t`: time points. + +## Keyword Arguments + + - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. + - `cache_parameters`: precompute parameters at initialization for faster interpolation computations. Note: if activated, `u` and `t` should not be modified. Defaults to `false`. +""" +function PCHIPInterpolation(u, t; extrapolate = false, cache_parameters = false) + u, t = munge_data(u, t) + du = du_PCHIP(u, t) + CubicHermiteSpline(du, u, t; extrapolate, cache_parameters) end """ - QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) + QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = true) It is a Quintic Hermite interpolation, which is a piece-wise fifth degree polynomial such that the value and the first and second derivative are equal to given values in the data points. @@ -851,17 +800,8 @@ end function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, cache_parameters = false) @assert length(u)==length(du)==length(ddu) "Length of `u` is not equal to length of `du` or `ddu`." u, t = munge_data(u, t) - p = if cache_parameters - QuinticHermiteParameterCache(ddu, du, u, t) - else - QuinticHermiteParameterCache(nothing, nothing, nothing) - end + p = QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters) A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, cache_parameters) - - if cache_parameters - I = cumulative_integral(A) - A = QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) - end - - A + I = cumulative_integral(A, cache_parameters) + QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, cache_parameters) end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 17b08328..2a2392bd 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -105,14 +105,15 @@ function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = : end end -function cumulative_integral(A) - if isempty(methods(_integral, (typeof(A), Any, Any))) - return nothing +function cumulative_integral(A, cache_parameters) + if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number}) + integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) + for idx in 1:(length(A.t) - 1)] + pushfirst!(integral_values, zero(first(integral_values))) + cumsum(integral_values) + else + promote_type(eltype(A.u), eltype(A.t))[] end - integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) - for idx in 1:(length(A.t) - 1)] - pushfirst!(integral_values, zero(first(integral_values))) - return cumsum(integral_values) end function get_parameters(A::LinearInterpolation, idx) @@ -162,3 +163,38 @@ function get_parameters(A::QuinticHermiteSpline, idx) quintic_hermite_spline_parameters(A.ddu, A.du, A.u, A.t, idx) end end + +function du_PCHIP(u, t) + h = diff(u) + δ = h ./ diff(t) + s = sign.(δ) + + function _du(k) + sₖ₋₁, sₖ = if k == 1 + s[1], s[2] + elseif k == lastindex(t) + s[end - 1], s[end] + else + s[k - 1], s[k] + end + + if sₖ₋₁ == 0 && sₖ == 0 + zero(eltype(δ)) + elseif sₖ₋₁ == sₖ + if k == 1 + ((2 * h[1] + h[2]) * δ[1] - h[1] * δ[2]) / (h[1] + h[2]) + elseif k == lastindex(t) + ((2 * h[end] + h[end - 1]) * δ[end] - h[end] * δ[end - 1]) / + (h[end] + h[end - 1]) + else + w₁ = 2h[k] + h[k - 1] + w₂ = h[k] + 2h[k - 1] + δ[k - 1] * δ[k] * (w₁ + w₂) / (w₁ * δ[k] + w₂ * δ[k - 1]) + end + else + zero(eltype(δ)) + end + end + + return _du.(eachindex(t)) +end diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 2820dc8f..0701b3a2 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -2,13 +2,28 @@ struct LinearParameterCache{pType} slope::pType end -function LinearParameterCache(u, t) - slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) - return LinearParameterCache(slope) +function LinearParameterCache(u, t, cache_parameters) + if cache_parameters + slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) + LinearParameterCache(slope) + else + # Compute parameters once to infer types + slope = linear_interpolation_parameters(u, t, 1) + LinearParameterCache(typeof(slope)[]) + end +end + +# Prevent e.g. Inf - Inf = NaN +function safe_diff(b, a::T) where {T} + b == a ? zero(T) : b - a end -function linear_interpolation_parameters(u, t, idx) - Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] +function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where {T} + Δu = if u isa AbstractMatrix + [safe_diff(u[j, idx + 1], u[j, idx]) for j in 1:size(u)[1]] + else + safe_diff(u[idx + 1], u[idx]) + end Δt = t[idx + 1] - t[idx] slope = Δu / Δt slope = iszero(Δt) ? zero(slope) : slope @@ -21,11 +36,18 @@ struct QuadraticParameterCache{pType} l₂::pType end -function QuadraticParameterCache(u, t) - parameters = quadratic_interpolation_parameters.( - Ref(u), Ref(t), 1:(length(t) - 2)) - l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) - return QuadraticParameterCache(l₀, l₁, l₂) +function QuadraticParameterCache(u, t, cache_parameters) + if cache_parameters + parameters = quadratic_interpolation_parameters.( + Ref(u), Ref(t), 1:(length(t) - 2)) + l₀, l₁, l₂ = collect.(eachrow(stack(collect.(parameters)))) + QuadraticParameterCache(l₀, l₁, l₂) + else + # Compute parameters once to infer types + l₀, l₁, l₂ = quadratic_interpolation_parameters(u, t, 1) + pType = typeof(l₀) + QuadraticParameterCache(pType[], pType[], pType[]) + end end function quadratic_interpolation_parameters(u, t, idx) @@ -54,9 +76,15 @@ struct QuadraticSplineParameterCache{pType} σ::pType end -function QuadraticSplineParameterCache(z, t) - σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) - return QuadraticSplineParameterCache(σ) +function QuadraticSplineParameterCache(z, t, cache_parameters) + if cache_parameters + σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) + QuadraticSplineParameterCache(σ) + else + # Compute parameters once to infer types + σ = quadratic_spline_parameters(z, t, 1) + QuadraticSplineParameterCache(typeof(σ)[]) + end end function quadratic_spline_parameters(z, t, idx) @@ -69,11 +97,18 @@ struct CubicSplineParameterCache{pType} c₂::pType end -function CubicSplineParameterCache(u, h, z) - parameters = cubic_spline_parameters.( - Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) - c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) - return CubicSplineParameterCache(c₁, c₂) +function CubicSplineParameterCache(u, h, z, cache_parameters) + if cache_parameters + parameters = cubic_spline_parameters.( + Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) + CubicSplineParameterCache(c₁, c₂) + else + # Compute parameters once to infer types + c₁, c₂ = cubic_spline_parameters(u, h, z, 1) + pType = typeof(c₁) + CubicSplineParameterCache(pType[], pType[]) + end end function cubic_spline_parameters(u, h, z, idx) @@ -87,11 +122,18 @@ struct CubicHermiteParameterCache{pType} c₂::pType end -function CubicHermiteParameterCache(du, u, t) - parameters = cubic_hermite_spline_parameters.( - Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) - return CubicHermiteParameterCache(c₁, c₂) +function CubicHermiteParameterCache(du, u, t, cache_parameters) + if cache_parameters + parameters = cubic_hermite_spline_parameters.( + Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) + c₁, c₂ = collect.(eachrow(stack(collect.(parameters)))) + CubicHermiteParameterCache(c₁, c₂) + else + # Compute parameters once to infer types + c₁, c₂ = cubic_hermite_spline_parameters(du, u, t, 1) + pType = typeof(c₁) + CubicHermiteParameterCache(pType[], pType[]) + end end function cubic_hermite_spline_parameters(du, u, t, idx) @@ -111,11 +153,18 @@ struct QuinticHermiteParameterCache{pType} c₃::pType end -function QuinticHermiteParameterCache(ddu, du, u, t) - parameters = quintic_hermite_spline_parameters.( - Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) - c₁, c₂, c₃ = collect.(eachrow(hcat(collect.(parameters)...))) - return QuinticHermiteParameterCache(c₁, c₂, c₃) +function QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters) + if cache_parameters + parameters = quintic_hermite_spline_parameters.( + Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) + c₁, c₂, c₃ = collect.(eachrow(stack(collect.(parameters)))) + QuinticHermiteParameterCache(c₁, c₂, c₃) + else + # Compute parameters once to infer types + c₁, c₂, c₃ = quintic_hermite_spline_parameters(ddu, du, u, t, 1) + pType = typeof(c₁) + QuinticHermiteParameterCache(pType[], pType[], pType[]) + end end function quintic_hermite_spline_parameters(ddu, du, u, t, idx) diff --git a/src/plot_rec.jl b/src/plot_rec.jl index a7bd8afc..6c576c49 100644 --- a/src/plot_rec.jl +++ b/src/plot_rec.jl @@ -16,7 +16,16 @@ function to_plottable(A::AbstractInterpolation; plotdensity = 10_000, denseplot end @recipe function f(A::AbstractInterpolation; plotdensity = 10_000, denseplot = true) - to_plottable(A; plotdensity = plotdensity, denseplot = denseplot) + @series begin + seriestype := :path + label --> string(nameof(typeof(A))) + to_plottable(A; plotdensity = plotdensity, denseplot = denseplot) + end + @series begin + seriestype := :scatter + label --> "Data points" + A.t, A.u + end end ################################################################################ @@ -35,18 +44,26 @@ end x, y, z; + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "Linear fit" - - nx, ny = to_plottable(LinearInterpolation(y, x); + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable(LinearInterpolation(T.(y), T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - - x := nx - y := ny + @series begin + seriestype := :path + label --> "LinearInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## @@ -57,18 +74,60 @@ end x, y, z; + mode = :Forward, + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable( + QuadraticInterpolation(T.(y), + T.(x), mode; extrapolate, safetycopy); + plotdensity = plotdensity, + denseplot = denseplot) + @series begin + seriestype := :path + label --> "QuadraticInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end +end - label --> "Quadratic fit" +######################################## +# Lagrange Interpolation # +######################################## - nx, ny = to_plottable(QuadraticInterpolation(T.(y), - T.(x)); +@recipe function f(::Type{Val{:lagrange_interp}}, + x, y, z; + n = length(x) - 1, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, + denseplot = true) + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable(LagrangeInterpolation(T.(y), + T.(x), + n; extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "LagrangeInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## @@ -79,96 +138,125 @@ end x, y, z; + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "Quadratic Spline" - T = promote_type(eltype(y), eltype(x)) - nx, ny = to_plottable(QuadraticSpline(T.(y), - T.(x)); + T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - - x := nx - y := ny + @series begin + seriestype := :path + label --> "QuadraticSpline" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## -# Lagrange Interpolation # +# Cubic Spline # ######################################## -@recipe function f(::Type{Val{:lagrange_interp}}, - x, y, z; - n = length(x) - 1, +@recipe function f(::Type{Val{:cubic_spline}}, + x, + y, + z; + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "Lagrange Fit" - T = promote_type(eltype(y), eltype(x)) - - nx, ny = to_plottable(LagrangeInterpolation(T.(y), - T.(x), - n); + nx, ny = to_plottable(CubicSpline(T.(y), + T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - - x := nx - y := ny + @series begin + seriestype := :path + label --> "CubicSpline" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## -# Cubic Spline # +# Akima interpolation # ######################################## -@recipe function f(::Type{Val{:cubic_spline}}, +@recipe function f(::Type{Val{:akima_interp}}, x, y, z; + extrapolate = false, + safetycopy = false, plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "Cubic Spline" - T = promote_type(eltype(y), eltype(x)) - - nx, ny = to_plottable(CubicSpline(T.(y), - T.(x)); + nx, ny = to_plottable(AkimaInterpolation(T.(y), + T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "AkimaInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end +######################################## +# B-spline Interpolation # +######################################## + @recipe function f(::Type{Val{:bspline_interp}}, x, y, z; d = 5, - pVec = :ArcLen, - knotVec = :Average, - plotdensity = length(x) * 6, + pVecType = :ArcLen, + knotVecType = :Average, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "B-Spline" - - @show x y eltype(x) - - # T = promote_type(eltype(y), eltype(x)) - - nx, ny = to_plottable(BSplineInterpolation(T.(y), + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable( + BSplineInterpolation(T.(y), T.(x), d, - pVec, - knotVec); + pVecType, + knotVecType; extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "BSplineInterpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## @@ -179,48 +267,133 @@ end x, y, z; d = 5, h = length(x) - 1, - pVec = :ArcLen, - knotVec = :Average, - plotdensity = length(x) * 6, + pVecType = :ArcLen, + knotVecType = :Average, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, denseplot = true) - seriestype := :path - - label --> "B-Spline" - T = promote_type(eltype(y), eltype(x)) - - nx, ny = to_plottable(BSplineApprox(T.(y), + nx, ny = to_plottable( + BSplineApprox(T.(y), T.(x), d, h, - pVec, - knotVec); + pVecType, + knotVecType; extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "BSplineApprox" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end ######################################## -# Akima interpolation # +# Cubic Hermite Spline # ######################################## -@recipe function f(::Type{Val{:akima}}, +@recipe function f(::Type{Val{:cubic_hermite_spline}}, x, y, z; - plotdensity = length(x) * 6, + du = nothing, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, denseplot = true) - seriestype := :path + isnothing(du) && error("Provide `du` as a keyword argument.") + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable( + CubicHermiteSpline(T.(du), T.(y), + T.(x); extrapolate, safetycopy); + plotdensity = plotdensity, + denseplot = denseplot) + @series begin + seriestype := :path + label --> "CubicHermiteSpline" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end +end - label --> "Akima" +######################################## +# PCHIP Interpolation # +######################################## +@recipe function f(::Type{Val{:pchip_interp}}, + x, + y, + z; + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, + denseplot = true) T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable(PCHIPInterpolation(T.(y), + T.(x); extrapolate, safetycopy); + plotdensity = plotdensity, + denseplot = denseplot) + @series begin + seriestype := :path + label --> "PCHIP Interpolation" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end +end - nx, ny = to_plottable(AkimaInterpolation(T.(y), - T.(x)); +######################################## +# Quintic Hermite Spline # +######################################## + +@recipe function f(::Type{Val{:quintic_hermite_spline}}, + x, + y, + z; + du = nothing, + ddu = nothing, + extrapolate = false, + safetycopy = false, + plotdensity = 10_000, + denseplot = true) + (isnothing(du) || isnothing(ddu)) && + error("Provide `du` and `ddu` as keyword arguments.") + T = promote_type(eltype(y), eltype(x)) + nx, ny = to_plottable( + QuinticHermiteSpline(T.(ddu), T.(du), T.(y), + T.(x); extrapolate, safetycopy); plotdensity = plotdensity, denseplot = denseplot) - x := nx - y := ny + @series begin + seriestype := :path + label --> "QuinticHermiteSpline" + x := nx + y := ny + end + @series begin + seriestype := :scatter + label --> "Data points" + x := x + y := y + end end diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 37351d0d..50abe4ac 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -82,6 +82,12 @@ end u = vcat(2.0collect(1:10)', 3.0collect(1:10)') test_derivatives( LinearInterpolation; args = [u, t], name = "Linear Interpolation (Matrix)") + + # Issue: https://github.com/SciML/DataInterpolations.jl/issues/303 + u = [3.0, 3.0] + t = [0.0, 2.0] + test_derivatives( + LinearInterpolation; args = [u, t], name = "Linear Interpolation with two points") end @testset "Quadratic Interpolation" begin diff --git a/test/interface.jl b/test/interface.jl index 5d02a22a..3e910547 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,20 +1,52 @@ using DataInterpolations -u = 2.0collect(1:10) -t = 1.0collect(1:10) -A = LinearInterpolation(u, t) +using Symbolics -for i in 1:10 - @test u[i] == A.u[i] -end +@testset "Interface" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + A = LinearInterpolation(u, t) + + for i in 1:10 + @test u[i] == A.u[i] + end -for i in 1:10 - @test t[i] == A.t[i] + for i in 1:10 + @test t[i] == A.t[i] + end end -using Symbolics -u = 2.0collect(1:10) -t = 1.0collect(1:10) -A = LinearInterpolation(u, t) +@testset "Symbolics" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + A = LinearInterpolation(u, t) + @variables t x(t) + substitute(A(t), Dict(t => x)) +end -@variables t x(t) -substitute(A(t), Dict(t => x)) +@testset "Type Inference" begin + u = 2.0collect(1:10) + t = 1.0collect(1:10) + methods = [ + ConstantInterpolation, LinearInterpolation, + QuadraticInterpolation, LagrangeInterpolation, + QuadraticSpline, CubicSpline, AkimaInterpolation + ] + @testset "$method" for method in methods + @inferred method(u, t) + end + @testset "BSplineInterpolation" begin + @inferred BSplineInterpolation(u, t, 3, :Uniform, :Uniform) + @inferred BSplineInterpolation(u, t, 3, :ArcLen, :Average) + end + @testset "BSplineApprox" begin + @inferred BSplineApprox(u, t, 3, 5, :Uniform, :Uniform) + @inferred BSplineApprox(u, t, 3, 5, :ArcLen, :Average) + end + du = ones(10) + ddu = zeros(10) + @testset "Hermite Splines" begin + @inferred CubicHermiteSpline(du, u, t) + @inferred PCHIPInterpolation(u, t) + @inferred QuinticHermiteSpline(ddu, du, u, t) + end +end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 9562c7b4..53fff25e 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -160,6 +160,13 @@ end @test A(5.5) == fill(11.0) @test A(11) == fill(22) + # Test constant -Inf interpolation + u = [-Inf, -Inf] + t = [0.0, 1.0] + A = LinearInterpolation(u, t) + @test A(0.0) == -Inf + @test A(0.5) == -Inf + # Test extrapolation u = 2.0collect(1:10) t = 1.0collect(1:10) @@ -169,6 +176,7 @@ end A = LinearInterpolation(u, t) @test_throws DataInterpolations.ExtrapolationError A(-1.0) @test_throws DataInterpolations.ExtrapolationError A(11.0) + @test_throws DataInterpolations.ExtrapolationError A([-1.0, 11.0]) end @testset "Quadratic Interpolation" begin @@ -669,6 +677,18 @@ end @test_throws AssertionError CubicHermiteSpline(du, u, t) end +@testset "PCHIPInterpolation" begin + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + t = [0.0, 62.25, 109.66, 162.66, 205.8, 250.0] + A = PCHIPInterpolation(u, t) + @test A isa CubicHermiteSpline + ts = 0.0:0.1:250.0 + us = A(ts) + @test all(minimum(u) .<= us) + @test all(maximum(u) .>= us) + @test all(A.du[3:4] .== 0.0) +end + @testset "Quintic Hermite Spline" begin test_interpolation_type(QuinticHermiteSpline) diff --git a/test/online_tests.jl b/test/online_tests.jl index 3ae6438e..1872e0cc 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -6,9 +6,11 @@ u1 = [0.0, 1.0, 0.0] t2 = [4.0, 5.0, 6.0] u2 = [1.0, 2.0, 1.0] -ts = 1.0:0.5:6.0 +ts_append = 1.0:0.5:6.0 +ts_push = 1.0:0.5:4.0 -for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] +@testset "$method" for method in [ + LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] func1 = method(copy(u1), copy(t1); cache_parameters = true) append!(func1, u2, t2) func2 = method(vcat(u1, u2), vcat(t1, t2); cache_parameters = true) @@ -17,7 +19,7 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio for name in propertynames(func1.p) @test getfield(func1.p, name) == getfield(func2.p, name) end - @test func1(ts) == func2(ts) + @test func1(ts_append) == func2(ts_append) @test func1.I == func2.I func1 = method(copy(u1), copy(t1); cache_parameters = true) @@ -28,6 +30,6 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio for name in propertynames(func1.p) @test getfield(func1.p, name) == getfield(func2.p, name) end - @test func1(ts) == func2(ts) + @test func1(ts_push) == func2(ts_push) @test func1.I == func2.I end From a5e0f2061a3e0dfa8492e2d1365efbcf878b0207 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 21:08:17 +0200 Subject: [PATCH 4/9] Zygote compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ccef602b..9219a059 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" -Zygote = "0.6" +Zygote = "0" julia = "1.10" [extras] From 8f44fa36f0b0f426b07b327f2235063d9ca2a4ff Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 21:11:27 +0200 Subject: [PATCH 5/9] Zygote compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9219a059..baf7d6ac 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" -Zygote = "0" +Zygote = "0.*" julia = "1.10" [extras] From 53905c141302e681b69a1e9c368b873f2c8b80f7 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sat, 27 Jul 2024 21:13:57 +0200 Subject: [PATCH 6/9] Zygote compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index baf7d6ac..518bfc2f 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" -Zygote = "0.*" +Zygote = "^0" julia = "1.10" [extras] From 8a3fd7ba480fa7b911d646025c150b51231dc25a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 27 Jul 2024 18:03:17 -0400 Subject: [PATCH 7/9] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 518bfc2f..ccef602b 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29" Test = "1" -Zygote = "^0" +Zygote = "0.6" julia = "1.10" [extras] From b8370765536e3653d6bb1a8c38e71ab4c7bbe2dc Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 27 Jul 2024 18:03:33 -0400 Subject: [PATCH 8/9] Update Project.toml --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index ccef602b..c9b7a2e9 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Optim = "429524aa-4258-5aef-a3af-852621145aeb" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DataInterpolationsChainRulesCoreExt = "ChainRulesCore" From e0cce0b491e7c05e1e3cab9876f105a0ecdad442 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sun, 28 Jul 2024 17:04:04 +0200 Subject: [PATCH 9/9] Add examples of how to speed up gradients w.r.t. u, add more zygote tests --- ext/DataInterpolationsChainRulesCoreExt.jl | 78 ++++++++++++++++++++-- test/zygote_tests.jl | 65 +++++++++++++----- 2 files changed, 124 insertions(+), 19 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 34e27841..9c33b09c 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -1,19 +1,87 @@ module DataInterpolationsChainRulesCoreExt - if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, + LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, get_idx, get_parameters, + _quad_interp_indices using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, + LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, get_parameters, + _quad_interp_indices using ..ChainRulesCore end +function ChainRulesCore.rrule( + ::Type{LinearInterpolation}, u, t, I, p, extrapolate, cache_parameters) + A = LinearInterpolation(u, t, I, p, extrapolate, cache_parameters) + function LinearInterpolation_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = NoTangent() + dextrapolate = NoTangent() + dcache_parameters = NoTangent() + df, du, dt, dI, dp, dextrapolate, dcache_parameters + end + + A, LinearInterpolation_pullback +end + +function ChainRulesCore.rrule( + ::Type{QuadraticInterpolation}, u, t, I, p, mode, extrapolate, cache_parameters) + A = QuadraticInterpolation(u, t, I, p, mode, extrapolate, cache_parameters) + function LinearInterpolation_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = NoTangent() + dmode = NoTangent() + dextrapolate = NoTangent() + dcache_parameters = NoTangent() + df, du, dt, dI, dp, dmode, dextrapolate, dcache_parameters + end + + A, LinearInterpolation_pullback +end + +function u_tangent(A::LinearInterpolation, t, Δ) + out = zero(A.u) + idx = get_idx(A.t, t, A.idx_prev[]) + t_factor = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) + out[idx] = Δ * (one(eltype(out)) - t_factor) + out[idx + 1] = Δ * t_factor + out +end + +function u_tangent(A::QuadraticInterpolation, t, Δ) + out = zero(A.u) + i₀, i₁, i₂ = _quad_interp_indices(A, t, A.idx_prev[]) + t₀ = A.t[i₀] + t₁ = A.t[i₁] + t₂ = A.t[i₂] + Δt₀ = t₁ - t₀ + Δt₁ = t₂ - t₁ + Δt₂ = t₂ - t₀ + out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂) + out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁) + out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁) + out +end + +function u_tangent(A, t, Δ) + NoTangent() +end + function ChainRulesCore.rrule(::typeof(_interpolate), A::Union{ + LinearInterpolation, + QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, @@ -21,7 +89,9 @@ function ChainRulesCore.rrule(::typeof(_interpolate), }, t::Number) deriv = derivative(A, t) - interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) + function interpolate_pullback(Δ) + (NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), deriv * Δ) + end return _interpolate(A, t), interpolate_pullback end diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl index 6887ddd2..1a7fc447 100644 --- a/test/zygote_tests.jl +++ b/test/zygote_tests.jl @@ -2,8 +2,8 @@ using DataInterpolations using ForwardDiff using Zygote -function test_zygote(method, u, t; args = [], kwargs = [], name::String) - func = method(args..., u, t; kwargs..., extrapolate = true) +function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name::String) + func = method(args..., u, t, args_after...; kwargs..., extrapolate = true) (; u, t) = func trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1)) trange_exclude = filter(x -> !in(x, t), trange) @@ -11,28 +11,30 @@ function test_zygote(method, u, t; args = [], kwargs = [], name::String) for _t in trange_exclude adiff = DataInterpolations.derivative(func, _t) zdiff = only(Zygote.gradient(func, _t)) - zdiff == nothing && (zdiff = 0.0) + isnothing(zdiff) && (zdiff = 0.0) @test adiff ≈ zdiff end end - @testset "$name, derivatives w.r.t. u" begin - function f(u) - A = method(args..., u, t; kwargs..., extrapolate = true) - out = zero(eltype(u)) - for _t in trange - out += A(_t) + if method ∉ [LagrangeInterpolation, BSplineInterpolation, BSplineApprox] + @testset "$name, derivatives w.r.t. u" begin + function f(u) + A = method(args..., u, t, args_after...; kwargs..., extrapolate = true) + out = zero(eltype(u)) + for _t in trange + out += A(_t) + end + out end - out + zgrad = only(Zygote.gradient(f, u)) + fgrad = ForwardDiff.gradient(f, u) + @test zgrad ≈ fgrad end - zgrad = only(Zygote.gradient(f, u)) - fgrad = ForwardDiff.gradient(f, u) - @test zgrad ≈ fgrad end end @testset "LinearInterpolation" begin - u = vcat(collect(1:5), 2 * collect(6:10)) - t = 1.0collect(1:10) + u = vcat(collect(1.0:5.0), 2 * collect(6.0:10.0)) + t = collect(1.0:10.0) test_zygote( LinearInterpolation, u, t; name = "Linear Interpolation") end @@ -64,3 +66,36 @@ end test_zygote( QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline") end + +@testset "Quadratic Spline" begin + u = [1.0, 4.0, 9.0, 16.0] + t = [1.0, 2.0, 3.0, 4.0] + test_zygote(QuadraticSpline, u, t, name = "Quadratic Spline") +end + +@testset "Lagrange Interpolation" begin + u = [1.0, 4.0, 9.0] + t = [1.0, 2.0, 3.0] + test_zygote(LagrangeInterpolation, u, t, name = "Lagrange Interpolation") +end + +@testset "Constant Interpolation" begin + u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t = collect(0.0:10.0) + test_zygote(ConstantInterpolation, u, t, name = "Constant Interpolation") +end + +@testset "Cubic Spline" begin + u = [0.0, 1.0, 3.0] + t = [-1.0, 0.0, 1.0] + test_zygote(CubicSpline, u, t, name = "Cubic Spline") +end + +@testset "BSplines" begin + t = [0, 62.25, 109.66, 162.66, 205.8, 252.3] + u = [14.7, 11.51, 10.41, 14.95, 12.24, 11.22] + test_zygote(BSplineInterpolation, u, t; args_after = [2, :Uniform, :Uniform], + name = "BSpline Interpolation") + test_zygote(BSplineApprox, u, t; args_after = [2, 4, :Uniform, :Uniform], + name = "BSpline approximation") +end