Skip to content

Commit

Permalink
Add some tests, refactor most of integration
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Nov 16, 2024
1 parent 3b3b626 commit c6ab856
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 66 deletions.
143 changes: 80 additions & 63 deletions src/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,89 +4,112 @@ end

function integral(A::AbstractInterpolation, t1::Number, t2::Number)
!hasfield(typeof(A), :I) && throw(IntegralNotFoundError())

if t1 == t2
# If the integration interval is trivial then the result is 0
return zero(eltype(A.I))
elseif t1 > t2
# Make sure that t1 < t2
return -integral(A, t2, t1)
end

# the index less than or equal to t1
idx1 = get_idx(A, t1, 0)
# the index less than t2
idx2 = get_idx(A, t2, 0; idx_shift = -1, side = :first)

if A.cache_parameters
total = A.I[idx2] - A.I[idx1]
return if t1 == t2
zero(total)
else
total += __integral(A, idx1, A.t[idx1])
total -= __integral(A, idx1, t1)
total += __integral(A, idx2, t2)
total -= __integral(A, idx2, A.t[idx2])
total
total = zero(eltype(A.I))

# Lower potentially incomplete interval
if t1 < first(A.t)

if t2 < first(A.t)
# If interval is entirely below data
return _extrapolate_integral_down(A, t2) - extrapolate_integral_down(A.t1)
end

idx1 -= 1 # Make sure lowest complete interval is included
total += _extrapolate_integral_down(A, t1)
else
total = zero(eltype(A.u))
for idx in idx1:idx2
lt1 = idx == idx1 ? t1 : A.t[idx]
lt2 = idx == idx2 ? t2 : A.t[idx + 1]
total += __integral(A, idx, lt2) - __integral(A, idx, lt1)
total += _integral(A, idx1, t1, A.t[idx1 + 1])
end

# Upper potentially incomplete interval
if t2 > last(A.t)

if t1 > last(A.t)
# If interval is entirely above data
return _extrapolate_integral_up(A, t2) - extrapolate_integral_up(A.t, t1)
end
total

idx2 += 1 # Make sure highest complete interval is included
total += _extrapolate_integral_up(A, t2)
else
total += _integral(A, idx2, A.t[idx2], t2)
end
end

function __integral(A::AbstractInterpolation, idx::Number, t::Number)
if t < first(A.t)
_extrapolate_integral_down(A, idx, t)
elseif t > last(A.t)
_extrapolate_integral_up(A, idx, t)
if idx1 == idx2
return _integral(A, idx1, t1, t2)
end

# Complete intervals
if A.cache_parameters
total += A.I[idx2] - A.I[idx1 + 1]
else
_integral(A, idx, t)
for idx in (idx1 + 1):(idx2 - 1)
total += _integral(A, idx, A.t[idx], A.t[idx + 1])
end
end

return total
end

function _extrapolate_integral_down(A, idx, t)
function _extrapolate_integral_down(A, t)
(; extrapolation_down) = A
if extrapolation_down == ExtrapolationType.none
throw(DownExtrapolationError())
elseif extrapolation_down == ExtrapolationType.constant
first(A.u) * (t - first(A.t))
first(A.u) * (first(A.t) - t)
elseif extrapolation_down == ExtrapolationType.linear
slope = derivative(A, first(A.t))
Δt = t - first(A.t)
(first(A.u) + slope * Δt / 2) * Δt
Δt = first(A.t) - t
(first(A.u) - slope * Δt / 2) * Δt
elseif extrapolation_down == ExtrapolationType.extension
_integral(A, idx, t)
_integral(A, 1, t, first(A.t))
end
end

function _extrapolate_integral_up(A, idx, t)
function _extrapolate_integral_up(A, t)
(; extrapolation_up) = A
if extrapolation_up == ExtrapolationType.none
throw(UpExtrapolationError())
elseif extrapolation_up == ExtrapolationType.constant
integral(A, A.t[end - 1], A.t[end]) + last(A.u) * (t - last(A.t))
last(A.u) * (t - last(A.t))
elseif extrapolation_up == ExtrapolationType.linear
slope = derivative(A, last(A.t))
Δt = t - last(A.t)
integral(A, A.t[end - 1], A.t[end]) + (last(A.u) + slope * Δt / 2) * Δt
(last(A.u) + slope * Δt / 2) * Δt
elseif extrapolation_up == ExtrapolationType.extension
_integral(A, idx, t)
_integral(A, length(A.t) - 1, last(A.t), t)
end
end

function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}},
idx::Number,
t::Number)
Δt = t - A.t[idx]
idx::Number, t1::Number, t2::Number)
slope = get_parameters(A, idx)
Δt * (A.u[idx] + slope * Δt / 2)
u_mean = A.u[idx] + slope * ((t1 + t2)/2 - A.t[idx])
u_mean * (t2 - t1)
end

function _integral(
A::ConstantInterpolation{<:AbstractVector{<:Number}}, idx::Number, t::Number)
A::ConstantInterpolation{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
Δt = t2 - t1
if A.dir === :left
# :left means that value to the left is used for interpolation
return A.u[idx] * t
return A.u[idx] * Δt
else
# :right means that value to the right is used for interpolation
return A.u[idx + 1] * t
return A.u[idx + 1] * Δt
end
end

Expand All @@ -107,30 +130,27 @@ function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}},
return Iu₀ + Iu₁ + Iu₂
end

function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
α, β = get_parameters(A, idx)
uᵢ = A.u[idx]
Δt = t - A.t[idx]
Δt_full = A.t[idx + 1] - A.t[idx]
Δt ** Δt^2 / (3Δt_full^2) + β * Δt / (2Δt_full) + uᵢ)
tᵢ = A.t[idx]
t1_rel = t1 - tᵢ
t2_rel = t2 - tᵢ
Δt = t2 - t1
Δt ** (t2_rel^2 + t1_rel * t2_rel + t1_rel^2) / 3 + β * (t2_rel + t1_rel) / 2 + uᵢ)
end

function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Δt₁sq = (t - A.t[idx])^2 / 2
Δt₂sq = (A.t[idx + 1] - t)^2 / 2
II = (-A.z[idx] * Δt₂sq^2 + A.z[idx + 1] * Δt₁sq^2) / (6A.h[idx + 1])
function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
tᵢ = A.t[idx]
tᵢ₊₁ = A.t[idx + 1]
c₁, c₂ = get_parameters(A, idx)
IC = c₁ * Δt₁sq
ID = -c₂ * Δt₂sq
II + IC + ID
integrate_cubic_polynomial(t1, t2, tᵢ, 0, c₁, 0, A.z[idx + 1] / (6A.h[idx + 1])) +
integrate_cubic_polynomial(t1, t2, tᵢ₊₁, 0, -c₂, 0, -A.z[idx] / (6A.h[idx + 1]))
end

function _integral(A::AkimaInterpolation{<:AbstractVector{<:Number}},
idx::Number,
t::Number)
t1 = A.t[idx]
A.u[idx] * (t - t1) + A.b[idx] * ((t - t1)^2 / 2) + A.c[idx] * ((t - t1)^3 / 3) +
A.d[idx] * ((t - t1)^4 / 4)
idx::Number, t1::Number, t2::Number)
integrate_cubic_polynomial(t1, t2, A.t[idx], A.u[idx], A.b[idx], A.c[idx], A.d[idx])
end

_integral(A::LagrangeInterpolation, idx::Number, t::Number) = throw(IntegralNotFoundError())
Expand All @@ -139,15 +159,12 @@ _integral(A::BSplineApprox, idx::Number, t::Number) = throw(IntegralNotFoundErro

# Cubic Hermite Spline
function _integral(
A::CubicHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
out = Δt₀ * (A.u[idx] + Δt₀ * A.du[idx] / 2)
A::CubicHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
c₁, c₂ = get_parameters(A, idx)
p = c₁ + Δt₁ * c₂
dp = c₂
out += Δt₀^3 / 3 * (p - dp * Δt₀ / 4)
out
tᵢ = A.t[idx]
tᵢ₊₁ = A.t[idx + 1]
c = c₁ - c₂ * (tᵢ₊₁ - tᵢ)
integrate_cubic_polynomial(t1, t2, tᵢ, A.u[idx], A.du[idx], c, c₂)
end

# Quintic Hermite Spline
Expand Down
14 changes: 11 additions & 3 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,8 @@ function get_idx(A::AbstractInterpolation, t, iguess::Union{<:Integer, Guesser};
end

function cumulative_integral(A, cache_parameters)
if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number})
integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx])
for idx in 1:(length(A.t) - 1)]
if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number, Number})
integral_values = _integral.(Ref(A), 1:(length(A.t) - 1), A.t[1:end-1], A.t[2:end])
pushfirst!(integral_values, zero(first(integral_values)))
cumsum(integral_values)
else
Expand Down Expand Up @@ -282,3 +281,12 @@ function du_PCHIP(u, t)

return _du.(eachindex(t))
end

function integrate_cubic_polynomial(t1, t2, offset, a, b, c, d)
t1_rel = t1 - offset
t2_rel = t2 - offset
t_sum = t1_rel + t2_rel
t_sq_sum = t1_rel^2 + t2_rel^2
Δt = t2 - t1
Δt * (a + t_sum * (b / 2 + d * t_sq_sum / 4) + c * (t_sq_sum + t1_rel * t2_rel) / 3)
end
84 changes: 84 additions & 0 deletions test/extrapolation_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
using DataInterpolations

@testset "Linear Interpolation" begin
u = [1.0, 2.0]
t = [1.0, 2.0]

A = LinearInterpolation(u, t; extrapolation_down = ExtrapolationType.constant)
t_eval = 0.0
@test A(t_eval) == 1.0
@test DataInterpolations.derivative(A, t_eval) == 0.0
@test DataInterpolations.derivative(A, t_eval, 2) == 0.0
@test DataInterpolations.integral(A, t_eval) == -1.0
t_eval = 3.0
@test_throws DataInterpolations.UpExtrapolationError A(t_eval)
@test_throws DataInterpolations.UpExtrapolationError DataInterpolations.derivative(A, t_eval)
@test_throws DataInterpolations.UpExtrapolationError DataInterpolations.derivative(A, t_eval, 2)
@test_throws DataInterpolations.UpExtrapolationError DataInterpolations.integral(A, t_eval)

A = LinearInterpolation(u, t; extrapolation_up = ExtrapolationType.constant)
t_eval = 3.0
@test A(t_eval) == 2.0
@test DataInterpolations.derivative(A, t_eval) == 0.0
@test DataInterpolations.derivative(A, t_eval, 2) == 0.0
@test DataInterpolations.integral(A, t_eval) == 3.5
t_eval = 0.0
@test_throws DataInterpolations.DownExtrapolationError A(t_eval)
@test_throws DataInterpolations.DownExtrapolationError DataInterpolations.derivative(A, t_eval)
@test_throws DataInterpolations.DownExtrapolationError DataInterpolations.derivative(A, t_eval, 2)
@test_throws DataInterpolations.DownExtrapolationError DataInterpolations.integral(A, t_eval)

for extrapolation_type in [ExtrapolationType.linear, ExtrapolationType.extension]
A = LinearInterpolation(u, t; extrapolation_down = extrapolation_type)
t_eval = 0.0
@test A(t_eval) == 0.0
@test DataInterpolations.derivative(A, t_eval) == 1.0
@test DataInterpolations.derivative(A, t_eval, 2) == 0.0
@test DataInterpolations.integral(A, t_eval) == -0.5
t_eval = 3.0
@test_throws DataInterpolations.UpExtrapolationError A(t_eval)
@test_throws DataInterpolations.UpExtrapolationError DataInterpolations.derivative(A, t_eval)
@test_throws DataInterpolations.UpExtrapolationError DataInterpolations.derivative(A, t_eval, 2)
@test_throws DataInterpolations.UpExtrapolationError DataInterpolations.integral(A, t_eval)

A = LinearInterpolation(u, t; extrapolation_up = extrapolation_type)
t_eval = 3.0
@test A(t_eval) == 3.0
@test DataInterpolations.derivative(A, t_eval) == 1.0
@test DataInterpolations.derivative(A, t_eval, 2) == 0.0
@test DataInterpolations.integral(A, t_eval) == 4.0
t_eval = 0.0
@test_throws DataInterpolations.DownExtrapolationError A(t_eval)
@test_throws DataInterpolations.DownExtrapolationError DataInterpolations.derivative(A, t_eval)
@test_throws DataInterpolations.DownExtrapolationError DataInterpolations.derivative(A, t_eval, 2)
@test_throws DataInterpolations.DownExtrapolationError DataInterpolations.integral(A, t_eval)
end
end

@testset "Quadratic Interpolation" begin
u = [1.0, 3.0, 2.0]
t = 1:3
f = t -> (-3t^2 + 13t - 8)/2

A = QuadraticInterpolation(u, t; extrapolation_down = ExtrapolationType.constant)
t_eval = 0.0
@test A(t_eval) 1.0
@test DataInterpolations.derivative(A, t_eval) == 0.0
@test DataInterpolations.derivative(A, t_eval, 2) == 0.0
@test DataInterpolations.integral(A, t_eval) -1.0
t_eval = 4.0
@test_throws DataInterpolations.UpExtrapolationError DataInterpolations.derivative(A, t_eval)
@test_throws DataInterpolations.UpExtrapolationError DataInterpolations.derivative(A, t_eval, 2)
@test_throws DataInterpolations.UpExtrapolationError DataInterpolations.integral(A, t_eval)

A = QuadraticInterpolation(u, t; extrapolation_up = ExtrapolationType.constant)
t_eval = 4.0
@test A(t_eval) 2.0
@test DataInterpolations.derivative(A, t_eval) == 0.0
@test DataInterpolations.derivative(A, t_eval, 2) == 0.0
@test DataInterpolations.integral(A, t[end], t_eval) 2.0
t_eval = 0.0
@test_throws DataInterpolations.DownExtrapolationError DataInterpolations.derivative(A, t_eval)
@test_throws DataInterpolations.DownExtrapolationError DataInterpolations.derivative(A, t_eval, 2)
@test_throws DataInterpolations.DownExtrapolationError DataInterpolations.integral(A, t_eval)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using SafeTestsets
@safetestset "Derivative Tests" include("derivative_tests.jl")
@safetestset "Integral Tests" include("integral_tests.jl")
@safetestset "Integral Inverse Tests" include("integral_inverse_tests.jl")
@safetestset "Extrapolation" include("extrapolation_tests.jl")
@safetestset "Online Tests" include("online_tests.jl")
@safetestset "Regularization Smoothing" include("regularization.jl")
@safetestset "Show methods" include("show.jl")
Expand Down

0 comments on commit c6ab856

Please sign in to comment.