Skip to content

Commit

Permalink
Add tests + formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Nov 22, 2024
1 parent d547e32 commit 6023946
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 42 deletions.
4 changes: 2 additions & 2 deletions docs/src/extrapolation_methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ this extrapolation type extends the interpolation such that `A(t + T) == A(t)` f
T = last(A.t) - first(A.t)
t_eval_left = range(first(t) - 2T, first(t), length = 100)
t_eval_right = range(last(t), last(t) + 2T, length = 100)
A = QuadraticSpline(u,t; extrapolation = ExtrapolationType.periodic)
A = QuadraticSpline(u, t; extrapolation = ExtrapolationType.periodic)
plot(A)
plot!(t_eval_left, A.(t_eval_left); label = "extrapolation down")
plot!(t_eval_right, A.(t_eval_right); label = "extrapolation up")
Expand All @@ -71,7 +71,7 @@ plot!(t_eval_right, A.(t_eval_right); label = "extrapolation up")
this extrapolation type extends the interpolation such that `A(t_ + t) == A(t_ - t)` for all `t_, t` such that `(t_ - first(A.t)) % T == 0` and `0 < t < T`, where `T = last(A.t) - first(A.t)`.

```@example tutorial
A = QuadraticSpline(u,t; extrapolation = ExtrapolationType.reflective)
A = QuadraticSpline(u, t; extrapolation = ExtrapolationType.reflective)
plot(A)
plot!(t_eval_left, A.(t_eval_left); label = "extrapolation down")
plot!(t_eval_right, A.(t_eval_right); label = "extrapolation up")
Expand Down
8 changes: 4 additions & 4 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ function _extrapolate_derivative_down(A, t, order)
derivative(A, t_, order)
else
# extrapolation_left == ExtrapolationType.reflective
t_, _ = transformation_reflective(A, t)
derivative(A, t_, order)
t_, n = transformation_reflective(A, t)
isodd(n) ? -derivative(A, t_, order) : derivative(A, t_, order)
end
end

Expand All @@ -58,8 +58,8 @@ function _extrapolate_derivative_up(A, t, order)
derivative(A, t_, order)
else
# extrapolation_right == ExtrapolationType.reflective
t_, _ = transformation_reflective(A, t)
derivative(A, t_, order)
t_, n = transformation_reflective(A, t)
iseven(n) ? -derivative(A, t_, order) : derivative(A, t_, order)
end
end

Expand Down
5 changes: 4 additions & 1 deletion src/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number)

# Complete intervals
if A.cache_parameters
total += A.I[idx2] - A.I[idx1 + 1]
total += A.I[idx2 - 1]
if idx1 > 0
total -= A.I[idx1]
end
else
for idx in (idx1 + 1):(idx2 - 1)
total += _integral(A, idx, A.t[idx], A.t[idx + 1])
Expand Down
2 changes: 1 addition & 1 deletion src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function _extrapolate_up(A, t)
# extrapolation_right == ExtrapolationType.reflective
t_, _ = transformation_reflective(A, t)
_interpolate(A, t_, A.iguesser)
end
end
end

# Linear Interpolation
Expand Down
2 changes: 1 addition & 1 deletion src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,4 @@ function transformation_reflective(A, t)
t_ = isodd(n) ? last(A.t) - t_ : first(A.t) + t_
(n > 0) && (n -= 1)
t_, n
end
end
75 changes: 42 additions & 33 deletions test/extrapolation_tests.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,53 @@
using DataInterpolations
using DataInterpolations, Test
using ForwardDiff
using QuadGK

function test_extrapolation_errors(method, u, t)
A = method(u, t)
@test A.extrapolation_right == ExtrapolationType.none
@test A.extrapolation_left == ExtrapolationType.none
for (error_type, t_eval) in zip(
(DataInterpolations.LeftExtrapolationError,
DataInterpolations.RightExtrapolationError),
(first(t) - 1, last(t) + 1))
@test_throws error_type A(t_eval)
@test_throws error_type DataInterpolations.derivative(
A, t_eval)
@test_throws error_type DataInterpolations.derivative(
A, t_eval, 2)
@test_throws error_type DataInterpolations.integral(
A, t_eval)
function test_extrapolation(method, u, t)
@testset "Extrapolation errors" begin
A = method(u, t)
@test A.extrapolation_right == ExtrapolationType.none
@test A.extrapolation_left == ExtrapolationType.none
for (error_type, t_eval) in zip(
(DataInterpolations.LeftExtrapolationError,
DataInterpolations.RightExtrapolationError),
(first(t) - 1, last(t) + 1))
@test_throws error_type A(t_eval)
@test_throws error_type DataInterpolations.derivative(
A, t_eval)
@test_throws error_type DataInterpolations.derivative(
A, t_eval, 2)
@test_throws error_type DataInterpolations.integral(
A, t_eval)
end
end
end

function test_constant_extrapolation(method, u, t)
A = method(u, t; extrapolation_left = ExtrapolationType.constant,
extrapolation_right = ExtrapolationType.constant)
t_lower = first(t) - 1
t_upper = last(t) + 1
@test A(t_lower) == first(u)
@test A(t_upper) == last(u)
@test DataInterpolations.derivative(A, t_lower) == 0
@test DataInterpolations.derivative(A, t_upper) == 0
@test DataInterpolations.integral(A, t_lower, first(t))
first(u) * (first(t) - t_lower)
@test DataInterpolations.integral(A, last(t), t_upper) last(u) * (t_upper - last(t))
for extrapolation_type in instances(ExtrapolationType.T)
(extrapolation_type == ExtrapolationType.none) && continue
@testset "extrapolation type $extrapolation_type" begin
A = method(u, t; extrapolation = extrapolation_type)

t_eval = first(t) - 1.5
@test DataInterpolations.derivative(A, t_eval)
ForwardDiff.derivative(A, t_eval)

t_eval = last(t) + 1.5
@test DataInterpolations.derivative(A, t_eval)
ForwardDiff.derivative(A, t_eval)

T = last(A.t) - first(A.t)
t1 = first(t) - 2.5T
t2 = last(t) + 3.5T
@test DataInterpolations.integral(A, t1, t2)
quadgk(A, t1, t2; atol = 1e-12, rtol = 1e-12)[1]
end
end
end

@testset "Linear Interpolation" begin
u = [1.0, 2.0]
t = [1.0, 2.0]

test_extrapolation_errors(LinearInterpolation, u, t)
test_constant_extrapolation(LinearInterpolation, u, t)
test_extrapolation(LinearInterpolation, u, t)

for extrapolation_type in [ExtrapolationType.linear, ExtrapolationType.extension]
# Down extrapolation
Expand All @@ -64,8 +74,7 @@ end
u = [1.0, 3.0, 2.0]
t = 1:3

test_extrapolation_errors(QuadraticInterpolation, u, t)
test_constant_extrapolation(LinearInterpolation, u, t)
test_extrapolation(QuadraticInterpolation, u, t)

# Linear down extrapolation
A = QuadraticInterpolation(u, t; extrapolation_left = ExtrapolationType.linear)
Expand Down
13 changes: 13 additions & 0 deletions test/parameter_tests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
using DataInterpolations

function test_cached_integration(method, args...)
A_c = method(args...; cache_parameters = true)
A_nc = method(args...; cache_parameters = false)
@test DataInterpolations.integral(A_c, last(A_c.t))
DataInterpolations.integral(A_nc, last(A_nc.t))
end

@testset "Linear Interpolation" begin
u = [1.0, 5.0, 3.0, 4.0, 4.0]
t = collect(1:5)
A = LinearInterpolation(u, t; cache_parameters = true)
@test A.p.slope [4.0, -2.0, 1.0, 0.0]
test_cached_integration(LinearInterpolation, u, t)
end

@testset "Quadratic Interpolation" begin
Expand All @@ -13,6 +21,7 @@ end
A = QuadraticInterpolation(u, t; cache_parameters = true)
@test A.p.α [-3.0, 1.5, -0.5, -0.5]
@test A.p.β [7.0, -3.5, 1.5, 0.5]
test_cached_integration(QuadraticInterpolation, u, t)
end

@testset "Quadratic Spline" begin
Expand All @@ -21,6 +30,7 @@ end
A = QuadraticSpline(u, t; cache_parameters = true)
@test A.p.α [-9.5, 3.5, -0.5, -0.5]
@test A.p.β [13.5, -5.5, 1.5, 0.5]
test_cached_integration(QuadraticSpline, u, t)
end

@testset "Cubic Spline" begin
Expand All @@ -29,6 +39,7 @@ end
A = CubicSpline(u, t; cache_parameters = true)
@test A.p.c₁ [6.839285714285714, 1.642857142857143, 4.589285714285714, 4.0]
@test A.p.c₂ [1.0, 6.839285714285714, 1.642857142857143, 4.589285714285714]
test_cached_integration(CubicSpline, u, t)
end

@testset "Cubic Hermite Spline" begin
Expand All @@ -38,6 +49,7 @@ end
A = CubicHermiteSpline(du, u, t; cache_parameters = true)
@test A.p.c₁ [-1.0, -5.0, -5.0, -8.0]
@test A.p.c₂ [0.0, 13.0, 12.0, 9.0]
test_cached_integration(CubicHermiteSpline, du, u, t)
end

@testset "Quintic Hermite Spline" begin
Expand All @@ -49,4 +61,5 @@ end
@test A.p.c₁ [-1.0, -6.5, -8.0, -10.0]
@test A.p.c₂ [1.0, 19.5, 20.0, 19.0]
@test A.p.c₃ [1.5, -37.5, -37.0, -26.5]
test_cached_integration(QuinticHermiteSpline, ddu, du, u, t)
end

0 comments on commit 6023946

Please sign in to comment.