Skip to content

Commit

Permalink
Merge pull request #354 from SouthEndMusic/stable_quadratic_spline
Browse files Browse the repository at this point in the history
Update `QuadraticSpline`
  • Loading branch information
ChrisRackauckas authored Nov 12, 2024
2 parents d0dd800 + e5bd8dc commit 26d6f11
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 81 deletions.
8 changes: 5 additions & 3 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ end

# QuadraticSpline Interpolation
function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess; lb = 2, ub_shift = 0, side = :first)
σ = get_parameters(A, idx - 1)
A.z[idx - 1] + 2σ * (t - A.t[idx - 1])
idx = get_idx(A, t, iguess)
α, β = get_parameters(A, idx)
Δt = t - A.t[idx]
Δt_full = A.t[idx + 1] - A.t[idx]
2α * Δt / Δt_full^2 + β / Δt_full
end

# CubicSpline Interpolation
Expand Down
7 changes: 4 additions & 3 deletions src/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}},
end

function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Cᵢ = A.u[idx]
α, β = get_parameters(A, idx)
uᵢ = A.u[idx]
Δt = t - A.t[idx]
σ = get_parameters(A, idx)
return A.z[idx] * Δt^2 / 2 + σ * Δt^3 / 3 + Cᵢ * Δt
Δt_full = A.t[idx + 1] - A.t[idx]
Δt ** Δt^2 / (3Δt_full^2) + β * Δt / (2Δt_full) + uᵢ)
end

function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Expand Down
77 changes: 37 additions & 40 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,31 +305,31 @@ Extrapolation extends the last quadratic polynomial on each side.
for a test based on the normalized standard deviation of the difference with respect
to the straight line (see [`looks_linear`](@ref)). Defaults to 1e-2.
"""
struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T, N} <:
struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} <:
AbstractInterpolation{T, N}
u::uType
t::tType
I::IType
p::QuadraticSplineParameterCache{pType}
tA::tAType
d::dType
z::zType
k::kType # knot vector
c::cType # B-spline control points
sc::scType # Spline coefficients (preallocated memory)
extrapolate::Bool
iguesser::Guesser{tType}
cache_parameters::Bool
linear_lookup::Bool
function QuadraticSpline(
u, t, I, p, tA, d, z, extrapolate, cache_parameters, assume_linear_t)
u, t, I, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
linear_lookup = seems_linear(assume_linear_t, t)
N = get_output_dim(u)
new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA),
typeof(d), typeof(z), eltype(u), N}(u,
new{typeof(u), typeof(t), typeof(I), typeof(p.α), typeof(k),
typeof(c), typeof(sc), eltype(u), N}(u,
t,
I,
p,
tA,
d,
z,
k,
c,
sc,
extrapolate,
Guesser(t),
cache_parameters,
Expand All @@ -343,50 +343,47 @@ function QuadraticSpline(
cache_parameters = false, assume_linear_t = 1e-2) where {uType <:
AbstractVector{<:Number}}
u, t = munge_data(u, t)
linear_lookup = seems_linear(assume_linear_t, t)
s = length(t)
dl = ones(eltype(t), s - 1)
d_tmp = ones(eltype(t), s)
du = zeros(eltype(t), s - 1)
tA = Tridiagonal(dl, d_tmp, du)

# zero for element type of d, which we don't know yet
typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin]))

d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
z = tA \ d
n = length(t)
dtype_sc = typeof(t[1] / t[1])
sc = zeros(dtype_sc, n)
k, A = quadratic_spline_params(t, sc)
c = A \ u

p = QuadraticSplineParameterCache(z, t, cache_parameters)
p = QuadraticSplineParameterCache(u, t, k, c, sc, cache_parameters)
A = QuadraticSpline(
u, t, nothing, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
u, t, nothing, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
I = cumulative_integral(A, cache_parameters)
QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
QuadraticSpline(u, t, I, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
end

function QuadraticSpline(
u::uType, t; extrapolate = false, cache_parameters = false,
assume_linear_t = 1e-2) where {uType <:
AbstractVector}
u, t = munge_data(u, t)
linear_lookup = seems_linear(assume_linear_t, t)
s = length(t)
dl = ones(eltype(t), s - 1)
d_tmp = ones(eltype(t), s)
du = zeros(eltype(t), s - 1)
tA = Tridiagonal(dl, d_tmp, du)
d_ = map(
i -> i == 1 ? zeros(eltype(t), size(u[1])) :
2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]),
1:s)
d = transpose(reshape(reduce(hcat, d_), :, s))
z_ = reshape(transpose(tA \ d), size(u[1])..., :)
z = [z_s for z_s in eachslice(z_, dims = ndims(z_))]

p = QuadraticSplineParameterCache(z, t, cache_parameters)
n = length(t)
dtype_sc = typeof(t[1] / t[1])
sc = zeros(dtype_sc, n)
k, A = quadratic_spline_params(t, sc)

eltype_c_prototype = one(dtype_sc) * first(u)
c = [similar(eltype_c_prototype) for _ in 1:n]

# Assuming u contains arrays of equal shape
for j in eachindex(eltype_c_prototype)
c_dim = A \ [u_[j] for u_ in u]
for (i, c_dim_) in enumerate(c_dim)
c[i][j] = c_dim_
end
end

p = QuadraticSplineParameterCache(u, t, k, c, sc, cache_parameters)
A = QuadraticSpline(
u, t, nothing, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
u, t, nothing, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
I = cumulative_integral(A, cache_parameters)
QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
QuadraticSpline(u, t, I, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
end

"""
Expand Down
8 changes: 4 additions & 4 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ end
# QuadraticSpline Interpolation
function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess)
Cᵢ = A.u[idx]
Δt = t - A.t[idx]
σ = get_parameters(A, idx)
return A.z[idx] * Δt + σ * Δt^2 + Cᵢ
α, β = get_parameters(A, idx)
uᵢ = A.u[idx]
Δt_scaled = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx])
Δt_scaled ** Δt_scaled + β) + uᵢ
end

# CubicSpline Interpolation
Expand Down
35 changes: 33 additions & 2 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,37 @@ function get_output_dim(u::AbstractArray)
return size(u)[1:(end - 1)]
end

function quadratic_spline_params(t::AbstractVector, sc::AbstractVector)

# Create knot vector
# Don't use x[end-1] as knot to match number of degrees of freedom with data
k = zeros(length(t) + 3)
k[1:3] .= t[1]
k[(end - 2):end] .= t[end]
k[4:(end - 3)] .= t[2:(end - 2)]

# Create linear system Ac = u, where:
# - A consists of basis function evaulations in t
# - c are 1D control points
n = length(t)
dtype_sc = typeof(t[1] / t[1])

diag = Vector{dtype_sc}(undef, n)
diag_hi = Vector{dtype_sc}(undef, n - 1)
diag_lo = Vector{dtype_sc}(undef, n - 1)

for (i, tᵢ) in enumerate(t)
spline_coefficients!(sc, 2, k, tᵢ)
diag[i] = sc[i]
(i > 1) && (diag_lo[i - 1] = sc[i - 1])
(i < n) && (diag_hi[i] = sc[i + 1])
end

A = Tridiagonal(diag_lo, diag, diag_hi)

return k, A
end

# helper function for data manipulation
function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real})
return u, t
Expand Down Expand Up @@ -187,9 +218,9 @@ end

function get_parameters(A::QuadraticSpline, idx)
if A.cache_parameters
A.p.σ[idx]
A.p.α[idx], A.p.β[idx]
else
quadratic_spline_parameters(A.z, A.t, idx)
quadratic_spline_parameters(A.u, A.t, A.k, A.c, A.sc, idx)
end
end

Expand Down
28 changes: 19 additions & 9 deletions src/parameter_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,33 @@ function quadratic_interpolation_parameters(u, t, idx)
end

struct QuadraticSplineParameterCache{pType}
σ::pType
α::pType
β::pType
end

function QuadraticSplineParameterCache(z, t, cache_parameters)
function QuadraticSplineParameterCache(u, t, k, c, sc, cache_parameters)
if cache_parameters
σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1))
QuadraticSplineParameterCache(σ)
parameters = quadratic_spline_parameters.(
Ref(u), Ref(t), Ref(k), Ref(c), Ref(sc), 1:(length(t) - 1))
α, β = collect.(eachrow(stack(collect.(parameters))))
QuadraticSplineParameterCache(α, β)
else
# Compute parameters once to infer types
σ = quadratic_spline_parameters(z, t, 1)
QuadraticSplineParameterCache(typeof(σ)[])
α, β = quadratic_spline_parameters(u, t, k, c, sc, 1)
QuadraticSplineParameterCache(typeof(α)[], typeof)[])
end
end

function quadratic_spline_parameters(z, t, idx)
σ = 1 // 2 * (z[idx + 1] - z[idx]) / (t[idx + 1] - t[idx])
return σ
function quadratic_spline_parameters(u, t, k, c, sc, idx)
tᵢ₊ = (t[idx] + t[idx + 1]) / 2
nonzero_coefficient_idxs = spline_coefficients!(sc, 2, k, tᵢ₊)
uᵢ₊ = zero(first(u))
for j in nonzero_coefficient_idxs
uᵢ₊ += sc[j] * c[j]
end
α = 2 * (u[idx + 1] + u[idx]) - 4uᵢ₊
β = 4 * (uᵢ₊ - u[idx]) - (u[idx + 1] - u[idx])
return α, β
end

struct CubicSplineParameterCache{pType}
Expand Down
4 changes: 2 additions & 2 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ end
derivexpr2 = expand_derivatives(substitute(D2(A(ω)), Dict=> 0.5τ)))
symfunc1 = Symbolics.build_function(derivexpr1, τ; expression = Val{false})
symfunc2 = Symbolics.build_function(derivexpr2, τ; expression = Val{false})
@test symfunc1(0.5) == 0.5 * 3
@test symfunc2(0.5) == 0.5 * 6
@test symfunc1(0.5) == 1.5
@test symfunc2(0.5) == -3.0

u = [0.0, 1.5, 0.0]
t = [0.0, 0.5, 1.0]
Expand Down
19 changes: 9 additions & 10 deletions test/interpolation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,39 +510,38 @@ end
A = QuadraticSpline(u, t; extrapolate = true)

# Solution
P₁ = x -> (x + 1)^2 # for x ∈ [-1, 0]
P₂ = x -> 2 * x + 1 # for x ∈ [ 0, 1]
P₁ = x -> 0.5 * (x + 1) * (x + 2)

for (_t, _u) in zip(t, u)
@test A(_t) == _u
end
@test A(-2.0) == P₁(-2.0)
@test A(-0.5) == P₁(-0.5)
@test A(0.7) == P(0.7)
@test A(2.0) == P(2.0)
@test A(0.7) == P(0.7)
@test A(2.0) == P(2.0)
test_cached_index(A)

u_ = [0.0, 1.0, 3.0]' .* ones(4)
u = [u_[:, i] for i in 1:size(u_, 2)]
A = QuadraticSpline(u, t; extrapolate = true)
@test A(-2.0) == P₁(-2.0) * ones(4)
@test A(-0.5) == P₁(-0.5) * ones(4)
@test A(0.7) == P(0.7) * ones(4)
@test A(2.0) == P(2.0) * ones(4)
@test A(0.7) == P(0.7) * ones(4)
@test A(2.0) == P(2.0) * ones(4)

u = [repeat(u[i], 1, 3) for i in 1:3]
A = QuadraticSpline(u, t; extrapolate = true)
@test A(-2.0) == P₁(-2.0) * ones(4, 3)
@test A(-0.5) == P₁(-0.5) * ones(4, 3)
@test A(0.7) == P(0.7) * ones(4, 3)
@test A(2.0) == P(2.0) * ones(4, 3)
@test A(0.7) == P(0.7) * ones(4, 3)
@test A(2.0) == P(2.0) * ones(4, 3)

# Test extrapolation
u = [0.0, 1.0, 3.0]
t = [-1.0, 0.0, 1.0]
A = QuadraticSpline(u, t; extrapolate = true)
@test A(-2.0) == 1.0
@test A(2.0) == 5.0
@test A(-2.0) == 0.0
@test A(2.0) == 6.0
A = QuadraticSpline(u, t)
@test_throws DataInterpolations.ExtrapolationError A(-2.0)
@test_throws DataInterpolations.ExtrapolationError A(2.0)
Expand Down
3 changes: 2 additions & 1 deletion test/parameter_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ end
u = [1.0, 5.0, 3.0, 4.0, 4.0]
t = collect(1:5)
A = QuadraticSpline(u, t; cache_parameters = true)
@test A.p.σ [4.0, -10.0, 13.0, -14.0]
@test A.p.α [-9.5, 3.5, -0.5, -0.5]
@test A.p.β [13.5, -5.5, 1.5, 0.5]
end

@testset "Cubic Spline" begin
Expand Down
9 changes: 2 additions & 7 deletions test/zygote_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name
@test adiff zdiff
end
end
if method [LagrangeInterpolation, BSplineInterpolation, BSplineApprox]
if method
[LagrangeInterpolation, BSplineInterpolation, BSplineApprox, QuadraticSpline]
@testset "$name, derivatives w.r.t. u" begin
function f(u)
A = method(args..., u, t, args_after...; kwargs..., extrapolate = true)
Expand Down Expand Up @@ -86,12 +87,6 @@ end
QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline")
end

@testset "Quadratic Spline" begin
u = [1.0, 4.0, 9.0, 16.0]
t = [1.0, 2.0, 3.0, 4.0]
test_zygote(QuadraticSpline, u, t, name = "Quadratic Spline")
end

@testset "Lagrange Interpolation" begin
u = [1.0, 4.0, 9.0]
t = [1.0, 2.0, 3.0]
Expand Down

0 comments on commit 26d6f11

Please sign in to comment.