Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor integration and QuadraticInterpolation #359

Merged
merged 4 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ if isdefined(Base, :get_extension)
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox, get_idx, get_parameters,
_quad_interp_indices, munge_data
munge_data
using ChainRulesCore
else
using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation,
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox, get_parameters,
_quad_interp_indices, munge_data
munge_data
using ..ChainRulesCore
end

Expand Down Expand Up @@ -74,6 +74,11 @@ function u_tangent(A::LinearInterpolation, t, Δ)
out
end

function _quad_interp_indices(A::QuadraticInterpolation, t::Number, iguess)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, can't u_tangent for QuadraticInterpolation be reformulated using α and β parameters you added?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'ts possible, but I suspect this is more efficient. Also, that might break AD types that do not support mutation when those parameters are cached?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And you still need to have the logic somewhere to see which mode (forward/backward looking interpolation) was used for the u gradient.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, then we can keep this.

idx = get_idx(A, t, iguess; idx_shift = A.mode == :Backward ? -1 : 0, ub_shift = -2)
idx, idx + 1, idx + 2
end

function u_tangent(A::QuadraticInterpolation, t, Δ)
out = zero.(A.u)
i₀, i₁, i₂ = _quad_interp_indices(A, t, A.iguesser)
Expand Down
10 changes: 4 additions & 6 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ function _derivative(A::LinearInterpolation, t::Number, iguess)
end

function _derivative(A::QuadraticInterpolation, t::Number, iguess)
i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess)
l₀, l₁, l₂ = get_parameters(A, i₀)
du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂])
du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂])
du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁])
return @views @. du₀ + du₁ + du₂
idx = get_idx(A, t, iguess)
Δt = t - A.t[idx]
α, β = get_parameters(A, idx)
return 2α * Δt + β
end

function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
Expand Down
6 changes: 5 additions & 1 deletion src/integral_inverses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}})
return all(A.u .> 0)
end

get_I(A::AbstractInterpolation) = isempty(A.I) ? cumulative_integral(A, true) : A.I
function get_I(A::AbstractInterpolation)
I = isempty(A.I) ? cumulative_integral(A, true) : copy(A.I)
pushfirst!(I, 0)
I
end

function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}})
!invertible_integral(A) && throw(IntegralNotInvertibleError())
Expand Down
122 changes: 59 additions & 63 deletions src/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,122 +7,118 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number)
((t1 < A.t[1] || t1 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
((t2 < A.t[1] || t2 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
!hasfield(typeof(A), :I) && throw(IntegralNotFoundError())
(t2 < t1) && return -integral(A, t2, t1)
# the index less than or equal to t1
idx1 = get_idx(A, t1, 0)
# the index less than t2
idx2 = get_idx(A, t2, 0; idx_shift = -1, side = :first)

if A.cache_parameters
total = A.I[idx2] - A.I[idx1]
total = A.I[max(1, idx2 - 1)] - A.I[idx1]
return if t1 == t2
zero(total)
else
total += _integral(A, idx1, A.t[idx1])
total -= _integral(A, idx1, t1)
total += _integral(A, idx2, t2)
total -= _integral(A, idx2, A.t[idx2])
if idx1 == idx2
total += _integral(A, idx1, t1, t2)
else
total += _integral(A, idx1, t1, A.t[idx1 + 1])
total += _integral(A, idx2, A.t[idx2], t2)
end
total
end
else
total = zero(eltype(A.u))
for idx in idx1:idx2
lt1 = idx == idx1 ? t1 : A.t[idx]
lt2 = idx == idx2 ? t2 : A.t[idx + 1]
total += _integral(A, idx, lt2) - _integral(A, idx, lt1)
total += _integral(A, idx, lt1, lt2)
end
total
end
end

function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}},
idx::Number,
t::Number)
Δt = t - A.t[idx]
idx::Number, t1::Number, t2::Number)
slope = get_parameters(A, idx)
Δt * (A.u[idx] + slope * Δt / 2)
u_mean = A.u[idx] + slope * ((t1 + t2) / 2 - A.t[idx])
u_mean * (t2 - t1)
end

function _integral(
A::ConstantInterpolation{<:AbstractVector{<:Number}}, idx::Number, t::Number)
A::ConstantInterpolation{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
Δt = t2 - t1
if A.dir === :left
# :left means that value to the left is used for interpolation
return A.u[idx] * t
return A.u[idx] * Δt
else
# :right means that value to the right is used for interpolation
return A.u[idx + 1] * t
return A.u[idx + 1] * Δt
end
end

function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}},
idx::Number,
t::Number)
A.mode == :Backward && idx > 1 && (idx -= 1)
idx = min(length(A.t) - 2, idx)
t₀ = A.t[idx]
t₁ = A.t[idx + 1]
t₂ = A.t[idx + 2]

t_sq = (t^2) / 3
l₀, l₁, l₂ = get_parameters(A, idx)
Iu₀ = l₀ * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂)
Iu₁ = l₁ * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂)
Iu₂ = l₂ * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁)
return Iu₀ + Iu₁ + Iu₂
idx::Number, t1::Number, t2::Number)
α, β = get_parameters(A, idx)
uᵢ = A.u[idx]
tᵢ = A.t[idx]
t1_rel = t1 - tᵢ
t2_rel = t2 - tᵢ
Δt = t2 - t1
Δt * (α * (t2_rel^2 + t1_rel * t2_rel + t1_rel^2) / 3 + β * (t2_rel + t1_rel) / 2 + uᵢ)
end

function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
function _integral(
A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
α, β = get_parameters(A, idx)
uᵢ = A.u[idx]
Δt = t - A.t[idx]
Δt_full = A.t[idx + 1] - A.t[idx]
Δt * (α * Δt^2 / (3Δt_full^2) + β * Δt / (2Δt_full) + uᵢ)
tᵢ = A.t[idx]
t1_rel = t1 - tᵢ
t2_rel = t2 - tᵢ
Δt = t2 - t1
Δt * (α * (t2_rel^2 + t1_rel * t2_rel + t1_rel^2) / 3 + β * (t2_rel + t1_rel) / 2 + uᵢ)
end

function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Δt₁sq = (t - A.t[idx])^2 / 2
Δt₂sq = (A.t[idx + 1] - t)^2 / 2
II = (-A.z[idx] * Δt₂sq^2 + A.z[idx + 1] * Δt₁sq^2) / (6A.h[idx + 1])
function _integral(
A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
tᵢ = A.t[idx]
tᵢ₊₁ = A.t[idx + 1]
c₁, c₂ = get_parameters(A, idx)
IC = c₁ * Δt₁sq
ID = -c₂ * Δt₂sq
II + IC + ID
integrate_cubic_polynomial(t1, t2, tᵢ, 0, c₁, 0, A.z[idx + 1] / (6A.h[idx + 1])) +
integrate_cubic_polynomial(t1, t2, tᵢ₊₁, 0, -c₂, 0, -A.z[idx] / (6A.h[idx + 1]))
end

function _integral(A::AkimaInterpolation{<:AbstractVector{<:Number}},
idx::Number,
t::Number)
t1 = A.t[idx]
A.u[idx] * (t - t1) + A.b[idx] * ((t - t1)^2 / 2) + A.c[idx] * ((t - t1)^3 / 3) +
A.d[idx] * ((t - t1)^4 / 4)
idx::Number, t1::Number, t2::Number)
integrate_cubic_polynomial(t1, t2, A.t[idx], A.u[idx], A.b[idx], A.c[idx], A.d[idx])
end

_integral(A::LagrangeInterpolation, idx::Number, t::Number) = throw(IntegralNotFoundError())
_integral(A::BSplineInterpolation, idx::Number, t::Number) = throw(IntegralNotFoundError())
_integral(A::BSplineApprox, idx::Number, t::Number) = throw(IntegralNotFoundError())
function _integral(A::LagrangeInterpolation, idx::Number, t1::Number, t2::Number)
throw(IntegralNotFoundError())
end
function _integral(A::BSplineInterpolation, idx::Number, t1::Number, t2::Number)
throw(IntegralNotFoundError())
end
function _integral(A::BSplineApprox, idx::Number, t1::Number, t2::Number)
throw(IntegralNotFoundError())
end

# Cubic Hermite Spline
function _integral(
A::CubicHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
out = Δt₀ * (A.u[idx] + Δt₀ * A.du[idx] / 2)
A::CubicHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
c₁, c₂ = get_parameters(A, idx)
p = c₁ + Δt₁ * c₂
dp = c₂
out += Δt₀^3 / 3 * (p - dp * Δt₀ / 4)
out
tᵢ = A.t[idx]
tᵢ₊₁ = A.t[idx + 1]
c = c₁ - c₂ * (tᵢ₊₁ - tᵢ)
integrate_cubic_polynomial(t1, t2, tᵢ, A.u[idx], A.du[idx], c, c₂)
end

# Quintic Hermite Spline
function _integral(
A::QuinticHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
out = Δt₀ * (A.u[idx] + A.du[idx] * Δt₀ / 2 + A.ddu[idx] * Δt₀^2 / 6)
A::QuinticHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
tᵢ = A.t[idx]
tᵢ₊₁ = A.t[idx + 1]
Δt = tᵢ₊₁ - tᵢ
c₁, c₂, c₃ = get_parameters(A, idx)
p = c₁ + c₂ * Δt₁ + c₃ * Δt₁^2
dp = c₂ + 2c₃ * Δt₁
ddp = 2c₃
out += Δt₀^4 / 4 * (p - Δt₀ / 5 * dp + Δt₀^2 / 30 * ddp)
out
integrate_quintic_polynomial(t1, t2, tᵢ, A.u[idx], A.du[idx], A.ddu[idx] / 2,
c₁ + Δt * (-c₂ + c₃ * Δt), c₂ - 2c₃ * Δt, c₃)
end
4 changes: 2 additions & 2 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T, N} <:
error("mode should be :Forward or :Backward for QuadraticInterpolation")
linear_lookup = seems_linear(assume_linear_t, t)
N = get_output_dim(u)
new{typeof(u), typeof(t), typeof(I), typeof(p.l₀), eltype(u), N}(
new{typeof(u), typeof(t), typeof(I), typeof(p.α), eltype(u), N}(
u, t, I, p, mode, extrapolate, Guesser(t), cache_parameters, linear_lookup)
end
end
Expand All @@ -93,7 +93,7 @@ function QuadraticInterpolation(
u, t, mode; extrapolate = false, cache_parameters = false, assume_linear_t = 1e-2)
u, t = munge_data(u, t)
linear_lookup = seems_linear(assume_linear_t, t)
p = QuadraticParameterCache(u, t, cache_parameters)
p = QuadraticParameterCache(u, t, cache_parameters, mode)
A = QuadraticInterpolation(
u, t, nothing, p, mode, extrapolate, cache_parameters, linear_lookup)
I = cumulative_integral(A, cache_parameters)
Expand Down
18 changes: 6 additions & 12 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,13 @@ function _interpolate(A::LinearInterpolation{<:AbstractArray}, t::Number, iguess
end

# Quadratic Interpolation
_quad_interp_indices(A, t) = _quad_interp_indices(A, t, firstindex(A.t) - 1)
function _quad_interp_indices(A::QuadraticInterpolation, t::Number, iguess)
idx = get_idx(A, t, iguess; idx_shift = A.mode == :Backward ? -1 : 0, ub_shift = -2)
idx, idx + 1, idx + 2
end

function _interpolate(A::QuadraticInterpolation, t::Number, iguess)
i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess)
l₀, l₁, l₂ = get_parameters(A, i₀)
u₀ = l₀ * (t - A.t[i₁]) * (t - A.t[i₂])
u₁ = l₁ * (t - A.t[i₀]) * (t - A.t[i₂])
u₂ = l₂ * (t - A.t[i₀]) * (t - A.t[i₁])
return u₀ + u₁ + u₂
idx = get_idx(A, t, iguess)
Δt = t - A.t[idx]
α, β = get_parameters(A, idx)
out = A.u isa AbstractMatrix ? A.u[:, idx] : A.u[idx]
out += @. Δt * (α * Δt + β)
out
end

# Lagrange Interpolation
Expand Down
30 changes: 24 additions & 6 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,9 @@ function get_idx(A::AbstractInterpolation, t, iguess::Union{<:Integer, Guesser};
end

function cumulative_integral(A, cache_parameters)
if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number})
integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx])
for idx in 1:(length(A.t) - 1)]
pushfirst!(integral_values, zero(first(integral_values)))
if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number, Number})
integral_values = _integral.(
Ref(A), 1:(length(A.t) - 1), A.t[1:(end - 1)], A.t[2:end])
cumsum(integral_values)
else
promote_type(eltype(A.u), eltype(A.t))[]
Expand All @@ -210,9 +209,9 @@ end

function get_parameters(A::QuadraticInterpolation, idx)
if A.cache_parameters
A.p.l₀[idx], A.p.l₁[idx], A.p.l₂[idx]
A.p.α[idx], A.p.β[idx]
else
quadratic_interpolation_parameters(A.u, A.t, idx)
quadratic_interpolation_parameters(A.u, A.t, idx, A.mode)
end
end

Expand Down Expand Up @@ -282,3 +281,22 @@ function du_PCHIP(u, t)

return _du.(eachindex(t))
end

function integrate_cubic_polynomial(t1, t2, offset, a, b, c, d)
t1_rel = t1 - offset
t2_rel = t2 - offset
t_sum = t1_rel + t2_rel
t_sq_sum = t1_rel^2 + t2_rel^2
Δt = t2 - t1
Δt * (a + t_sum * (b / 2 + d * t_sq_sum / 4) + c * (t_sq_sum + t1_rel * t2_rel) / 3)
end

function integrate_quintic_polynomial(t1, t2, offset, a, b, c, d, e, f)
t1_rel = t1 - offset
t2_rel = t2 - offset
t_sum = t1_rel + t2_rel
t_sq_sum = t1_rel^2 + t2_rel^2
Δt = t2 - t1
Δt * (a + t_sum * (b / 2 + d * t_sq_sum / 4) + c * (t_sq_sum + t1_rel * t2_rel) / 3) +
e * (t2_rel^5 - t1_rel^5) / 5 + f * (t2_rel^6 - t1_rel^6) / 6
end
Loading
Loading