From f3350bb581dbd99b768a50eec09e14eeb0cf7389 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Sat, 6 Jul 2024 18:57:11 +0000 Subject: [PATCH 1/4] refactor: use BaryCentric Formula for Lagrange Interpolation --- src/interpolation_caches.jl | 27 +++++++++------- src/interpolation_methods.jl | 60 +++++++++--------------------------- src/interpolation_utils.jl | 32 +++---------------- src/parameter_caches.jl | 24 +++++++++++++++ 4 files changed, 59 insertions(+), 84 deletions(-) diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index bb6339fe..7e57117e 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -94,25 +94,24 @@ It is the method of interpolation using Lagrange polynomials of (k-1)th order pa - `extrapolate`: boolean value to allow extrapolation. Defaults to `false`. - `safetycopy`: boolean value to make a copy of `u` and `t`. Defaults to `true`. """ -struct LagrangeInterpolation{uType, tType, T, bcacheType} <: +struct LagrangeInterpolation{uType, tType, duType, T} <: AbstractInterpolation{T} u::uType t::tType n::Int - bcache::bcacheType - idxs::Vector{Int} + p::LagrangeParameterCache + du::duType + derp::LagrangeParameterCache extrapolate::Bool idx_prev::Base.RefValue{Int} safetycopy::Bool - function LagrangeInterpolation(u, t, n, extrapolate, safetycopy) - bcache = zeros(eltype(u[1]), n + 1) - idxs = zeros(Int, n + 1) - fill!(bcache, NaN) - new{typeof(u), typeof(t), eltype(u), typeof(bcache)}(u, + function LagrangeInterpolation(u, t, n, p, du, derp, extrapolate, safetycopy) + new{typeof(u), typeof(t), typeof(du), eltype(u)}(u, t, n, - bcache, - idxs, + p, + du, + derp, extrapolate, Ref(1), safetycopy @@ -126,7 +125,13 @@ function LagrangeInterpolation( if n != length(t) - 1 error("Currently only n=length(t) - 1 is supported") end - LagrangeInterpolation(u, t, n, extrapolate, safetycopy) + p = lagrange_parameter_cache(u.parent, t.parent) + derpw = similar(t.parent) + derpw .= NaN + # Vector{Union{eltype(u), Missing}}(missing, s) + derp = LagrangeParameterCache(derpw, similar(u.parent)) + du = similar(u.parent) + LagrangeInterpolation(u, t, n, p, du, derp, extrapolate, safetycopy) end """ diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 5d09ceff..a36aba4a 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -58,61 +58,29 @@ end # Lagrange Interpolation function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number, iguess) - idx = get_idx(A.t, t, iguess) - findRequiredIdxs!(A, t, idx) - if A.t[A.idxs[1]] == t - return A.u[A.idxs[1]], idx - end + idx = _searchsortedfirst(A.t, t) + !isnothing(idx) && return A.u[idx], idx N = zero(A.u[1]) D = zero(A.t[1]) - tmp = N - for i in 1:length(A.idxs) - if isnan(A.bcache[A.idxs[i]]) - mult = one(A.t[1]) - for j in 1:(i - 1) - mult *= (A.t[A.idxs[i]] - A.t[A.idxs[j]]) - end - for j in (i + 1):length(A.idxs) - mult *= (A.t[A.idxs[i]] - A.t[A.idxs[j]]) - end - A.bcache[A.idxs[i]] = mult - else - mult = A.bcache[A.idxs[i]] - end - tmp = inv((t - A.t[A.idxs[i]]) * mult) - D += tmp - N += (tmp * A.u[A.idxs[i]]) + for i in 1:(A.n + 1) + ti = t - A.t[i] + N += (A.p.wu[i]) / ti + D += (A.p.w[i]) / ti end - N / D, idx + N / D, iguess end function _interpolate(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, iguess) - idx = get_idx(A.t, t, iguess) - findRequiredIdxs!(A, t, idx) - if A.t[A.idxs[1]] == t - return A.u[:, A.idxs[1]], idx - end + idx = _searchsortedfirst(A.t, t) + !isnothing(idx) && return A.u[:, idx], idx N = zero(A.u[:, 1]) D = zero(A.t[1]) - tmp = D - for i in 1:length(A.idxs) - if isnan(A.bcache[A.idxs[i]]) - mult = one(A.t[1]) - for j in 1:(i - 1) - mult *= (A.t[A.idxs[i]] - A.t[A.idxs[j]]) - end - for j in (i + 1):length(A.idxs) - mult *= (A.t[A.idxs[i]] - A.t[A.idxs[j]]) - end - A.bcache[A.idxs[i]] = mult - else - mult = A.bcache[A.idxs[i]] - end - tmp = inv((t - A.t[A.idxs[i]]) * mult) - D += tmp - @. N += (tmp * A.u[:, A.idxs[i]]) + for i in 1:(A.n + 1) + ti = t - A.t[i] + @. N += (A.p.wu[:, i]) / ti + D += (A.p.w[i]) / ti end - N / D, idx + N / D, iguess end function _interpolate(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess) diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 370c4062..1b6b1979 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -1,32 +1,10 @@ -function findRequiredIdxs!(A::LagrangeInterpolation, t, idx) - n = length(A.t) - 1 - i_min, idx_min, idx_max = if t == A.t[idx] - A.idxs[1] = idx - 2, idx, idx +function _searchsortedfirst(ts, t) + idx = searchsortedfirst(ts, t) + if idx > lastindex(ts) || ts[idx] != t + return nothing else - 1, idx + 1, idx + return idx end - for i in i_min:(n + 1) - if idx_min == 1 - A.idxs[i:end] .= range(idx_max + 1, idx_max + (n + 2 - i)) - break - elseif idx_max == length(A.t) - A.idxs[i:end] .= (idx_min - 1):-1:(idx_min - (n + 2 - i)) - break - else - left_diff = abs(t - A.t[idx_min - 1]) - right_diff = abs(t - A.t[idx_max + 1]) - left_expand = left_diff <= right_diff - end - if left_expand - idx_min -= 1 - A.idxs[i] = idx_min - else - idx_max += 1 - A.idxs[i] = idx_max - end - end - return idx end function spline_coefficients!(N, d, k, u::Number) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 2820dc8f..7fdf18ac 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -50,6 +50,30 @@ function quadratic_interpolation_parameters(u, t, idx) return l₀, l₁, l₂ end +struct LagrangeParameterCache{wType, wuType} + w::wType + wu::wuType +end + +function lagrange_parameter_cache(u, t) + w = similar(t) + wu = similar(u) + for i in eachindex(w) + mul = one(eltype(t)) + for j in eachindex(t) + i != j && (mul *= (t[i] - t[j])) + end + w[i] = inv(mul) + val = u isa Matrix ? w[i] .* u[:, i] : w[i] * u[i] + if u isa Matrix + wu[:, i] .= val + else + wu[i] = val + end + end + return LagrangeParameterCache(w, wu) +end + struct QuadraticSplineParameterCache{pType} σ::pType end From c17fc2657e93252abc010d6633f565ddcc05fef4 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Sat, 6 Jul 2024 18:58:07 +0000 Subject: [PATCH 2/4] refactor: derivatives with BaryCentric Formula for Lagrange Interpolation --- src/derivatives.jl | 100 +++++++++++++++++++++------------------------ 1 file changed, 46 insertions(+), 54 deletions(-) diff --git a/src/derivatives.jl b/src/derivatives.jl index 30c76fd0..30a2defa 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -32,70 +32,62 @@ end function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number) ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) - der = zero(A.u[1]) - for j in eachindex(A.t) - tmp = zero(A.t[1]) - if isnan(A.bcache[j]) - mult = one(A.t[1]) - for i in 1:(j - 1) - mult *= (A.t[j] - A.t[i]) - end - for i in (j + 1):length(A.t) - mult *= (A.t[j] - A.t[i]) - end - A.bcache[j] = mult - else - mult = A.bcache[j] - end - for l in eachindex(A.t) - if l != j - k = one(A.t[1]) - for m in eachindex(A.t) - if m != j && m != l - k *= (t - A.t[m]) - end - end - k *= inv(mult) - tmp += k + if all(isnan.(A.derp.w)) + for i in eachindex(A.u) + deru1 = zero(A.du[1]) + deru2 = zero(A.t.parent[1]) + for j in eachindex(A.t) + i == j && continue + val = (A.p.w[j] / A.p.w[i]) / (A.t[i] - A.t[j]) + deru1 += val * A.u[j] + deru2 += val end + A.du[i] = deru1 - deru2 * A.u[i] end - der += A.u[j] * tmp + der_temp_p = lagrange_parameter_cache(A.du, A.t.parent) + A.derp.w .= der_temp_p.w + A.derp.wu .= der_temp_p.wu + end + idx = _searchsortedfirst(A.t, t) + !isnothing(idx) && return A.du[idx] + N = zero(A.du[1]) + D = zero(A.t[1]) + for i in 1:(A.n + 1) + ti = t - A.t[i] + N += (A.derp.wu[i]) / ti + D += (A.derp.w[i]) / ti end - der + N / D end function _derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number) ((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) - der = zero(A.u[:, 1]) - for j in eachindex(A.t) - tmp = zero(A.t[1]) - if isnan(A.bcache[j]) - mult = one(A.t[1]) - for i in 1:(j - 1) - mult *= (A.t[j] - A.t[i]) - end - for i in (j + 1):length(A.t) - mult *= (A.t[j] - A.t[i]) - end - A.bcache[j] = mult - else - mult = A.bcache[j] - end - for l in eachindex(A.t) - if l != j - k = one(A.t[1]) - for m in eachindex(A.t) - if m != j && m != l - k *= (t - A.t[m]) - end - end - k *= inv(mult) - tmp += k + if all(isnan.(A.derp.w)) + for i in 1:size(A.u, 2) + deru1 = zero(A.du[:, 1]) + deru2 = zero(A.t.parent[1]) + for j in eachindex(A.t) + i == j && continue + val = (A.p.w[j] / A.p.w[i]) / (A.t[i] - A.t[j]) + @. deru1 += val * A.u[:, j] + deru2 += val end + @. A.du[:, i] = deru1 - deru2 * A.u[:, i] end - der += A.u[:, j] * tmp + der_temp_p = lagrange_parameter_cache(A.du, A.t.parent) + A.derp.w .= der_temp_p.w + A.derp.wu .= der_temp_p.wu + end + idx = _searchsortedfirst(A.t, t) + !isnothing(idx) && return A.du[:, idx] + N = zeros(promote_type(eltype(A.u), eltype(t)), length(A.u[:, 1])) + D = zero(A.t[1]) + for i in 1:(A.n + 1) + ti = t - A.t[i] + @. N += (A.derp.wu[:, i]) / ti + D += (A.derp.w[i]) / ti end - der + N / D end function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number, idx) From d5786eabd7e8cecee964a0a5155035855cc0739b Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Sat, 6 Jul 2024 18:58:54 +0000 Subject: [PATCH 3/4] test: approx for an exterior point instead of exact equality test --- test/interpolation_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 779c5bd8..2db72c96 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -323,7 +323,7 @@ end t = [1.0, 2.0, 3.0] A = LagrangeInterpolation(u, t; extrapolate = true) @test A(0.0) == 0.0 - @test A(4.0) == 16.0 + @test A(4.0) ≈ 16.0 A = LagrangeInterpolation(u, t) @test_throws DataInterpolations.ExtrapolationError A(-1.0) @test_throws DataInterpolations.ExtrapolationError A(4.0) From ab433952a9f5f6ad42253958253271b8814e8604 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Sun, 7 Jul 2024 07:46:28 +0000 Subject: [PATCH 4/4] fixup! refactor: use BaryCentric Formula for Lagrange Interpolation --- src/interpolation_caches.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 7e57117e..c5e292ae 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -128,9 +128,8 @@ function LagrangeInterpolation( p = lagrange_parameter_cache(u.parent, t.parent) derpw = similar(t.parent) derpw .= NaN - # Vector{Union{eltype(u), Missing}}(missing, s) derp = LagrangeParameterCache(derpw, similar(u.parent)) - du = similar(u.parent) + du = !(u.parent[1] isa AbstractVector || u.parent[1] isa AbstractMatrix) ? similar(u.parent) : similar.(u.parent) LagrangeInterpolation(u, t, n, p, du, derp, extrapolate, safetycopy) end