Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update QuadraticSpline #354

Merged
merged 9 commits into from
Nov 12, 2024
Merged
8 changes: 5 additions & 3 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ end

# QuadraticSpline Interpolation
function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess; lb = 2, ub_shift = 0, side = :first)
σ = get_parameters(A, idx - 1)
A.z[idx - 1] + 2σ * (t - A.t[idx - 1])
idx = get_idx(A, t, iguess)
α, β = get_parameters(A, idx)
Δt = t - A.t[idx]
Δt_full = A.t[idx + 1] - A.t[idx]
2α * Δt / Δt_full^2 + β / Δt_full
end

# CubicSpline Interpolation
Expand Down
7 changes: 4 additions & 3 deletions src/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}},
end

function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Cᵢ = A.u[idx]
α, β = get_parameters(A, idx)
uᵢ = A.u[idx]
Δt = t - A.t[idx]
σ = get_parameters(A, idx)
return A.z[idx] * Δt^2 / 2 + σ * Δt^3 / 3 + Cᵢ * Δt
Δt_full = A.t[idx + 1] - A.t[idx]
Δt * (α * Δt^2 / (3Δt_full^2) + β * Δt / (2Δt_full) + uᵢ)
end

function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Expand Down
77 changes: 37 additions & 40 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,31 +305,31 @@ Extrapolation extends the last quadratic polynomial on each side.
for a test based on the normalized standard deviation of the difference with respect
to the straight line (see [`looks_linear`](@ref)). Defaults to 1e-2.
"""
struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T, N} <:
struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} <:
AbstractInterpolation{T, N}
u::uType
t::tType
I::IType
p::QuadraticSplineParameterCache{pType}
tA::tAType
d::dType
z::zType
k::kType # knot vector
c::cType # B-spline control points
sc::scType # Spline coefficients (preallocated memory)
extrapolate::Bool
iguesser::Guesser{tType}
cache_parameters::Bool
linear_lookup::Bool
function QuadraticSpline(
u, t, I, p, tA, d, z, extrapolate, cache_parameters, assume_linear_t)
u, t, I, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
linear_lookup = seems_linear(assume_linear_t, t)
N = get_output_dim(u)
new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA),
typeof(d), typeof(z), eltype(u), N}(u,
new{typeof(u), typeof(t), typeof(I), typeof(p.α), typeof(k),
typeof(c), typeof(sc), eltype(u), N}(u,
t,
I,
p,
tA,
d,
z,
k,
c,
sc,
extrapolate,
Guesser(t),
cache_parameters,
Expand All @@ -343,50 +343,47 @@ function QuadraticSpline(
cache_parameters = false, assume_linear_t = 1e-2) where {uType <:
AbstractVector{<:Number}}
u, t = munge_data(u, t)
linear_lookup = seems_linear(assume_linear_t, t)
s = length(t)
dl = ones(eltype(t), s - 1)
d_tmp = ones(eltype(t), s)
du = zeros(eltype(t), s - 1)
tA = Tridiagonal(dl, d_tmp, du)

# zero for element type of d, which we don't know yet
typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin]))

d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
z = tA \ d
n = length(t)
dtype_sc = typeof(t[1] / t[1])
sc = zeros(dtype_sc, n)
k, A = quadratic_spline_params(t, sc)
c = A \ u

p = QuadraticSplineParameterCache(z, t, cache_parameters)
p = QuadraticSplineParameterCache(u, t, k, c, sc, cache_parameters)
A = QuadraticSpline(
u, t, nothing, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
u, t, nothing, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
I = cumulative_integral(A, cache_parameters)
QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
QuadraticSpline(u, t, I, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
end

function QuadraticSpline(
u::uType, t; extrapolate = false, cache_parameters = false,
assume_linear_t = 1e-2) where {uType <:
AbstractVector}
u, t = munge_data(u, t)
linear_lookup = seems_linear(assume_linear_t, t)
s = length(t)
dl = ones(eltype(t), s - 1)
d_tmp = ones(eltype(t), s)
du = zeros(eltype(t), s - 1)
tA = Tridiagonal(dl, d_tmp, du)
d_ = map(
i -> i == 1 ? zeros(eltype(t), size(u[1])) :
2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]),
1:s)
d = transpose(reshape(reduce(hcat, d_), :, s))
z_ = reshape(transpose(tA \ d), size(u[1])..., :)
z = [z_s for z_s in eachslice(z_, dims = ndims(z_))]

p = QuadraticSplineParameterCache(z, t, cache_parameters)
n = length(t)
dtype_sc = typeof(t[1] / t[1])
sc = zeros(dtype_sc, n)
k, A = quadratic_spline_params(t, sc)

eltype_c_prototype = one(dtype_sc) * first(u)
c = [similar(eltype_c_prototype) for _ in 1:n]

# Assuming u contains arrays of equal shape
for j in eachindex(eltype_c_prototype)
c_dim = A \ [u_[j] for u_ in u]
for (i, c_dim_) in enumerate(c_dim)
c[i][j] = c_dim_
end
end

p = QuadraticSplineParameterCache(u, t, k, c, sc, cache_parameters)
A = QuadraticSpline(
u, t, nothing, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
u, t, nothing, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
I = cumulative_integral(A, cache_parameters)
QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
QuadraticSpline(u, t, I, p, k, c, sc, extrapolate, cache_parameters, assume_linear_t)
end

"""
Expand Down
8 changes: 4 additions & 4 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ end
# QuadraticSpline Interpolation
function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess)
Cᵢ = A.u[idx]
Δt = t - A.t[idx]
σ = get_parameters(A, idx)
return A.z[idx] * Δt + σ * Δt^2 + Cᵢ
α, β = get_parameters(A, idx)
uᵢ = A.u[idx]
Δt_scaled = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx])
Δt_scaled * (α * Δt_scaled + β) + uᵢ
end

# CubicSpline Interpolation
Expand Down
35 changes: 33 additions & 2 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,37 @@ function get_output_dim(u::AbstractArray)
return size(u)[1:(end - 1)]
end

function quadratic_spline_params(t::AbstractVector, sc::AbstractVector)

# Create knot vector
# Don't use x[end-1] as knot to match number of degrees of freedom with data
k = zeros(length(t) + 3)
k[1:3] .= t[1]
k[(end - 2):end] .= t[end]
k[4:(end - 3)] .= t[2:(end - 2)]

# Create linear system Ac = u, where:
# - A consists of basis function evaulations in t
# - c are 1D control points
n = length(t)
dtype_sc = typeof(t[1] / t[1])

diag = Vector{dtype_sc}(undef, n)
diag_hi = Vector{dtype_sc}(undef, n - 1)
diag_lo = Vector{dtype_sc}(undef, n - 1)

for (i, tᵢ) in enumerate(t)
spline_coefficients!(sc, 2, k, tᵢ)
diag[i] = sc[i]
(i > 1) && (diag_lo[i - 1] = sc[i - 1])
(i < n) && (diag_hi[i] = sc[i + 1])
end

A = Tridiagonal(diag_lo, diag, diag_hi)

return k, A
end

# helper function for data manipulation
function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real})
return u, t
Expand Down Expand Up @@ -187,9 +218,9 @@ end

function get_parameters(A::QuadraticSpline, idx)
if A.cache_parameters
A.p.σ[idx]
A.p.α[idx], A.p.β[idx]
else
quadratic_spline_parameters(A.z, A.t, idx)
quadratic_spline_parameters(A.u, A.t, A.k, A.c, A.sc, idx)
end
end

Expand Down
28 changes: 19 additions & 9 deletions src/parameter_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,33 @@ function quadratic_interpolation_parameters(u, t, idx)
end

struct QuadraticSplineParameterCache{pType}
σ::pType
α::pType
β::pType
end

function QuadraticSplineParameterCache(z, t, cache_parameters)
function QuadraticSplineParameterCache(u, t, k, c, sc, cache_parameters)
if cache_parameters
σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1))
QuadraticSplineParameterCache(σ)
parameters = quadratic_spline_parameters.(
Ref(u), Ref(t), Ref(k), Ref(c), Ref(sc), 1:(length(t) - 1))
α, β = collect.(eachrow(stack(collect.(parameters))))
QuadraticSplineParameterCache(α, β)
else
# Compute parameters once to infer types
σ = quadratic_spline_parameters(z, t, 1)
QuadraticSplineParameterCache(typeof(σ)[])
α, β = quadratic_spline_parameters(u, t, k, c, sc, 1)
QuadraticSplineParameterCache(typeof(α)[], typeof(β)[])
end
end

function quadratic_spline_parameters(z, t, idx)
σ = 1 // 2 * (z[idx + 1] - z[idx]) / (t[idx + 1] - t[idx])
return σ
function quadratic_spline_parameters(u, t, k, c, sc, idx)
tᵢ₊ = (t[idx] + t[idx + 1]) / 2
nonzero_coefficient_idxs = spline_coefficients!(sc, 2, k, tᵢ₊)
uᵢ₊ = zero(first(u))
for j in nonzero_coefficient_idxs
uᵢ₊ += sc[j] * c[j]
end
α = 2 * (u[idx + 1] + u[idx]) - 4uᵢ₊
β = 4 * (uᵢ₊ - u[idx]) - (u[idx + 1] - u[idx])
return α, β
end

struct CubicSplineParameterCache{pType}
Expand Down
4 changes: 2 additions & 2 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ end
derivexpr2 = expand_derivatives(substitute(D2(A(ω)), Dict(ω => 0.5τ)))
symfunc1 = Symbolics.build_function(derivexpr1, τ; expression = Val{false})
symfunc2 = Symbolics.build_function(derivexpr2, τ; expression = Val{false})
@test symfunc1(0.5) == 0.5 * 3
@test symfunc2(0.5) == 0.5 * 6
@test symfunc1(0.5) == 1.5
@test symfunc2(0.5) == -3.0

u = [0.0, 1.5, 0.0]
t = [0.0, 0.5, 1.0]
Expand Down
19 changes: 9 additions & 10 deletions test/interpolation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,39 +510,38 @@ end
A = QuadraticSpline(u, t; extrapolate = true)

# Solution
P₁ = x -> (x + 1)^2 # for x ∈ [-1, 0]
P₂ = x -> 2 * x + 1 # for x ∈ [ 0, 1]
P₁ = x -> 0.5 * (x + 1) * (x + 2)

for (_t, _u) in zip(t, u)
@test A(_t) == _u
end
@test A(-2.0) == P₁(-2.0)
@test A(-0.5) == P₁(-0.5)
@test A(0.7) == P(0.7)
@test A(2.0) == P(2.0)
@test A(0.7) == P(0.7)
@test A(2.0) == P(2.0)
test_cached_index(A)

u_ = [0.0, 1.0, 3.0]' .* ones(4)
u = [u_[:, i] for i in 1:size(u_, 2)]
A = QuadraticSpline(u, t; extrapolate = true)
@test A(-2.0) == P₁(-2.0) * ones(4)
@test A(-0.5) == P₁(-0.5) * ones(4)
@test A(0.7) == P(0.7) * ones(4)
@test A(2.0) == P(2.0) * ones(4)
@test A(0.7) == P(0.7) * ones(4)
@test A(2.0) == P(2.0) * ones(4)

u = [repeat(u[i], 1, 3) for i in 1:3]
A = QuadraticSpline(u, t; extrapolate = true)
@test A(-2.0) == P₁(-2.0) * ones(4, 3)
@test A(-0.5) == P₁(-0.5) * ones(4, 3)
@test A(0.7) == P(0.7) * ones(4, 3)
@test A(2.0) == P(2.0) * ones(4, 3)
@test A(0.7) == P(0.7) * ones(4, 3)
@test A(2.0) == P(2.0) * ones(4, 3)

# Test extrapolation
u = [0.0, 1.0, 3.0]
t = [-1.0, 0.0, 1.0]
A = QuadraticSpline(u, t; extrapolate = true)
@test A(-2.0) == 1.0
@test A(2.0) == 5.0
@test A(-2.0) == 0.0
@test A(2.0) == 6.0
A = QuadraticSpline(u, t)
@test_throws DataInterpolations.ExtrapolationError A(-2.0)
@test_throws DataInterpolations.ExtrapolationError A(2.0)
Expand Down
3 changes: 2 additions & 1 deletion test/parameter_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ end
u = [1.0, 5.0, 3.0, 4.0, 4.0]
t = collect(1:5)
A = QuadraticSpline(u, t; cache_parameters = true)
@test A.p.σ ≈ [4.0, -10.0, 13.0, -14.0]
@test A.p.α ≈ [-9.5, 3.5, -0.5, -0.5]
@test A.p.β ≈ [13.5, -5.5, 1.5, 0.5]
end

@testset "Cubic Spline" begin
Expand Down
9 changes: 2 additions & 7 deletions test/zygote_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ function test_zygote(method, u, t; args = [], args_after = [], kwargs = [], name
@test adiff ≈ zdiff
end
end
if method ∉ [LagrangeInterpolation, BSplineInterpolation, BSplineApprox]
if method ∉
[LagrangeInterpolation, BSplineInterpolation, BSplineApprox, QuadraticSpline]
@testset "$name, derivatives w.r.t. u" begin
function f(u)
A = method(args..., u, t, args_after...; kwargs..., extrapolate = true)
Expand Down Expand Up @@ -86,12 +87,6 @@ end
QuinticHermiteSpline, u, t, args = [ddu, du], name = "Quintic Hermite Spline")
end

@testset "Quadratic Spline" begin
u = [1.0, 4.0, 9.0, 16.0]
t = [1.0, 2.0, 3.0, 4.0]
test_zygote(QuadraticSpline, u, t, name = "Quadratic Spline")
end

@testset "Lagrange Interpolation" begin
u = [1.0, 4.0, 9.0]
t = [1.0, 2.0, 3.0]
Expand Down
Loading