diff --git a/Project.toml b/Project.toml index 80cd1153..a32005d3 100644 --- a/Project.toml +++ b/Project.toml @@ -41,6 +41,7 @@ SafeTestsets = "0.1" StableRNGs = "1" Symbolics = "5.29, 6" Test = "1" +Unitful = "1" Zygote = "0.6.70" julia = "1.10" @@ -57,7 +58,8 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "BenchmarkTools", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Zygote"] +test = ["Aqua", "BenchmarkTools", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"] diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 54129ae6..bf33f31d 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -9,9 +9,9 @@ function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, igues if isnan(t) # For correct derivative with NaN idx = firstindex(A.u) - t1 = t2 = one(eltype(A.t)) - u1 = u2 = one(eltype(A.u)) - slope = t * get_parameters(A, idx) + t1 = t2 = oneunit(eltype(A.t)) + u1 = u2 = oneunit(eltype(A.u)) + slope = t/t * get_parameters(A, idx) else idx = get_idx(A, t, iguess) t1, t2 = A.t[idx], A.t[idx + 1] diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 65b48113..ec201560 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -3,6 +3,7 @@ using FindFirstFunctions: searchsortedfirstcorrelated using StableRNGs using Optim, ForwardDiff using BenchmarkTools +using Unitful function test_interpolation_type(T) @test T <: DataInterpolations.AbstractInterpolation @@ -105,6 +106,14 @@ end @test isnan(A(3.5)) @test isnan(A(4.0)) + u = [0.0, 1.0, 2.0, NaN] + A = LinearInterpolation(u, t; extrapolate = true) + @test A(1.0) == 0.0 + @test A(2.0) == 1.0 + @test A(3.0) == 2.0 + @test isnan(A(3.5)) + @test isnan(A(4.0)) + # Test type stability u = Float32.(1:5) t = Float32.(1:5) @@ -134,6 +143,12 @@ end @test @inferred(A(R64)) === A(R64) end + # NaN time value for Unitful arrays: issue #365 + t = (0:3)u"s" # Unitful quantities + u = [0, -2, -1, -2]u"m" + A = LinearInterpolation(u, t; extrapolate = true) + @test isnan(A(NaN*u"s")) + # Nan time value: t = 0.0:3 # Floats u = [0, -2, -1, -2]