From cae43d3550615d26a75a8035df13be6bab89e455 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Fri, 12 Jul 2024 08:09:48 +0200 Subject: [PATCH] Throw extrapolation error also when interpolation is called with an `AbstractVector` --- src/DataInterpolations.jl | 12 +++++++----- test/interpolation_tests.jl | 1 + test/online_tests.jl | 7 ++++--- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index c86a6579..752b0388 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -23,7 +23,11 @@ include("online.jl") include("show.jl") (interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t) -(interp::AbstractInterpolation)(t::Number, i::Integer) = _interpolate(interp, t, i) +function (interp::AbstractInterpolation)(t::Number, i::Integer) + interp.idx_prev[] = i + _interpolate(interp, t) +end + function (interp::AbstractInterpolation)(t::AbstractVector) u = get_u(interp.u, t) interp(u, t) @@ -44,16 +48,14 @@ function get_u(u::AbstractMatrix, t) end function (interp::AbstractInterpolation)(u::AbstractMatrix, t::AbstractVector) - iguess = firstindex(interp.t) @inbounds for i in eachindex(t) - u[:, i], iguess = interp(t[i], iguess) + u[:, i] = interp(t[i]) end u end function (interp::AbstractInterpolation)(u::AbstractVector, t::AbstractVector) - iguess = firstindex(interp.t) @inbounds for i in eachindex(u, t) - u[i], iguess = interp(t[i], iguess) + u[i] = interp(t[i]) end u end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 779c5bd8..b47a8075 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -170,6 +170,7 @@ end A = LinearInterpolation(u, t) @test_throws DataInterpolations.ExtrapolationError A(-1.0) @test_throws DataInterpolations.ExtrapolationError A(11.0) + @test_throws DataInterpolations.ExtrapolationError A([-1.0, 11.0]) end @testset "Quadratic Interpolation" begin diff --git a/test/online_tests.jl b/test/online_tests.jl index 571c2dc0..95818cc8 100644 --- a/test/online_tests.jl +++ b/test/online_tests.jl @@ -6,7 +6,8 @@ u1 = [0.0, 1.0, 0.0] t2 = [4.0, 5.0, 6.0] u2 = [1.0, 2.0, 1.0] -ts = 1.0:0.5:6.0 +ts_append = 1.0:0.5:6.0 +ts_push = 1.0:0.5:4.0 for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation] func1 = method(u1, t1) @@ -17,7 +18,7 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio for name in propertynames(func1.p) @test getfield(func1.p, name) == getfield(func2.p, name) end - @test func1(ts) == func2(ts) + @test func1(ts_append) == func2(ts_append) func1 = method(u1, t1) push!(func1, 1.0, 4.0) @@ -27,5 +28,5 @@ for method in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolatio for name in propertynames(func1.p) @test getfield(func1.p, name) == getfield(func2.p, name) end - @test func1(ts) == func2(ts) + @test func1(ts_push) == func2(ts_push) end