From 321e70f4f0c2aa50666d88321d23d02e9c0a4d0a Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 3 Dec 2024 16:10:40 +0100 Subject: [PATCH] Fix tests --- src/interpolation_caches.jl | 5 +++-- src/interpolation_utils.jl | 3 ++- src/parameter_caches.jl | 8 +++++--- test/parameter_tests.jl | 4 ++-- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 40c76e24..e9ad5a81 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -384,7 +384,7 @@ drifts far from the integral of constant interpolation. In this interpolation ty - `d_max`: the maximum distance in `t` from the data points the smoothing is allowed to reach. - `extrapolation`: The extrapolation type applied left and right of the data. Possible options are `ExtrapolationType.None` (default), `ExtrapolationType.Constant`, `ExtrapolationType.Linear` - `ExtrapolationType.Extension`, `ExtrapolationType.Periodic` and `ExtrapolationType.Reflective`. + `ExtrapolationType.Extension`, `ExtrapolationType.Periodic` (also made smooth at the boundaries) and `ExtrapolationType.Reflective`. - `extrapolation_left`: The extrapolation type applied left of the data. See `extrapolation` for the possible options. This keyword is ignored if `extrapolation != Extrapolation.none`. - `extrapolation_right`: The extrapolation type applied right of the data. See `extrapolation` for @@ -427,7 +427,8 @@ function SmoothedConstantInterpolation( extrapolation_left, extrapolation_right = munge_extrapolation( extrapolation, extrapolation_left, extrapolation_right) u, t = munge_data(u, t) - p = SmoothedConstantParameterCache(u, t, cache_parameters, d_max) + p = SmoothedConstantParameterCache( + u, t, cache_parameters, d_max, extrapolation_left, extrapolation_right) A = SmoothedConstantInterpolation( u, t, nothing, p, d_max, extrapolation_left, extrapolation_right, cache_parameters, assume_linear_t) diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index bb1f7492..eff36bb8 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -215,7 +215,8 @@ function get_parameters(A::SmoothedConstantInterpolation, idx) c_upper = A.p.c[idx + 1] d_lower, d_upper, c_lower, c_upper else - d_lower, c_lower = smoothed_constant_interpolation_parameters(A.u, A.t, A.d_max, idx, A.extrapolation_left, A.extrapolation_right) + d_lower, c_lower = smoothed_constant_interpolation_parameters( + A.u, A.t, A.d_max, idx, A.extrapolation_left, A.extrapolation_right) d_upper, c_upper = smoothed_constant_interpolation_parameters( A.u, A.t, A.d_max, idx + 1, A.extrapolation_left, A.extrapolation_right) d_lower, d_upper, c_lower, c_upper diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index ca4c5773..c5e1daf1 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -37,10 +37,11 @@ struct SmoothedConstantParameterCache{dType, cType} c::cType end -function SmoothedConstantParameterCache(u, t, cache_parameters, d_max) +function SmoothedConstantParameterCache( + u, t, cache_parameters, d_max, extrapolation_left, extrapolation_right) if cache_parameters parameters = smoothed_constant_interpolation_parameters.( - Ref(u), Ref(t), d_max, eachindex(t)) + Ref(u), Ref(t), d_max, eachindex(t), extrapolation_left, extrapolation_right) d, c = collect.(eachrow(stack(collect.(parameters)))) SmoothedConstantParameterCache(d, c) else @@ -48,7 +49,8 @@ function SmoothedConstantParameterCache(u, t, cache_parameters, d_max) end end -function smoothed_constant_interpolation_parameters(u, t, d_max, idx, extrapolation_left, extrapolation_right) +function smoothed_constant_interpolation_parameters( + u, t, d_max, idx, extrapolation_left, extrapolation_right) if isone(idx) || (idx == length(t)) # If extrapolation is periodic, make the transition differentiable if extrapolation_left == extrapolation_right == ExtrapolationType.Periodic diff --git a/test/parameter_tests.jl b/test/parameter_tests.jl index 978a1d51..40293517 100644 --- a/test/parameter_tests.jl +++ b/test/parameter_tests.jl @@ -19,8 +19,8 @@ end u = [1.0, 5.0, 3.0, 4.0, 4.0] t = collect(1:5) A = SmoothedConstantInterpolation(u, t; cache_parameters = true) - A.p.d ≈ [0.0, 0.5, 0.5, 0.5, 0.0] - A.p.c ≈ [0.0, 2.0, -1.0, 0.5, 0.0] + @test A.p.d ≈ [0.0, 0.5, 0.5, 0.5, 0.0] + @test A.p.c ≈ [0.0, 2.0, -1.0, 0.5, 0.0] end @testset "Quadratic Interpolation" begin