diff --git a/Project.toml b/Project.toml index 80cd1153..48e92f6f 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" version = "6.6.0" [deps] +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -27,6 +28,7 @@ DataInterpolationsSymbolicsExt = "Symbolics" Aqua = "0.8" BenchmarkTools = "1" ChainRulesCore = "1.24" +EnumX = "1.0.4" FindFirstFunctions = "1.3" FiniteDifferences = "0.12.31" ForwardDiff = "0.10.36" diff --git a/src/DataInterpolations.jl b/src/DataInterpolations.jl index 9e56531f..eca8f783 100644 --- a/src/DataInterpolations.jl +++ b/src/DataInterpolations.jl @@ -7,9 +7,12 @@ abstract type AbstractInterpolation{T, N} end using LinearAlgebra, RecipesBase using PrettyTables using ForwardDiff +using EnumX import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated, Guesser +@enumx ExtrapolationType none constant linear extension + include("parameter_caches.jl") include("interpolation_caches.jl") include("interpolation_utils.jl") @@ -55,13 +58,13 @@ function (interp::AbstractInterpolation)(u::AbstractVector, t::AbstractVector) u end -const DOWN_EXTRAPOLATION_ERROR = "Cannot extrapolate down as `extrapolation_down` keyword passed was `:none`" +const DOWN_EXTRAPOLATION_ERROR = "Cannot extrapolate down as `extrapolation_down` keyword passed was `none`" struct DownExtrapolationError <: Exception end function Base.showerror(io::IO, ::DownExtrapolationError) print(io, DOWN_EXTRAPOLATION_ERROR) end -const UP_EXTRAPOLATION_ERROR = "Cannot extrapolate up as `extrapolation_up` keyword passed was `:none`" +const UP_EXTRAPOLATION_ERROR = "Cannot extrapolate up as `extrapolation_up` keyword passed was `none`" struct UpExtrapolationError <: Exception end function Base.showerror(io::IO, ::UpExtrapolationError) print(io, UP_EXTRAPOLATION_ERROR) @@ -94,9 +97,8 @@ end export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline, BSplineInterpolation, BSplineApprox, CubicHermiteSpline, PCHIPInterpolation, - QuinticHermiteSpline, LinearInterpolationIntInv, ConstantInterpolationIntInv - -const extrapolation_types::Vector{Symbol} = [:none, :constant, :linear, :extension] + QuinticHermiteSpline, LinearInterpolationIntInv, ConstantInterpolationIntInv, + ExtrapolationType # added for RegularizationSmooth, JJS 11/27/21 ### Regularization data smoothing and interpolation diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index f8f65755..992ab265 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -314,8 +314,8 @@ struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} < k::kType # knot vector c::cType # B-spline control points sc::scType # Spline coefficients (preallocated memory) - extrapolation_down::Symbol - extrapolation_up::Symbol + extrapolation_down::ExtrapolationType.T + extrapolation_up::ExtrapolationType.T iguesser::Guesser{tType} cache_parameters::Bool linear_lookup::Bool @@ -324,8 +324,6 @@ struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} < extrapolation_up, cache_parameters, assume_linear_t) linear_lookup = seems_linear(assume_linear_t, t) N = get_output_dim(u) - validate_extrapolation(extrapolation_down) - validate_extrapolation(extrapolation_up) new{typeof(u), typeof(t), typeof(I), typeof(p.α), typeof(k), typeof(c), typeof(sc), eltype(u), N}(u, t, @@ -344,7 +342,8 @@ struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} < end function QuadraticSpline( - u::uType, t; extrapolation_down::Symbol = :none, extrapolation_up::Symbol = :none, + u::uType, t; extrapolation_down::ExtrapolationType.T = ExtrapolationType.none, + extrapolation_up::ExtrapolationType.T = ExtrapolationType.none, cache_parameters = false, assume_linear_t = 1e-2) where {uType <: AbstractVector{<:Number}} u, t = munge_data(u, t) @@ -357,9 +356,11 @@ function QuadraticSpline( p = QuadraticSplineParameterCache(u, t, k, c, sc, cache_parameters) A = QuadraticSpline( - u, t, nothing, p, k, c, sc, extrapolation_down, extrapolation_up, cache_parameters, assume_linear_t) + u, t, nothing, p, k, c, sc, extrapolation_down, + extrapolation_up, cache_parameters, assume_linear_t) I = cumulative_integral(A, cache_parameters) - QuadraticSpline(u, t, I, p, k, c, sc, extrapolation_down, extrapolation_up, cache_parameters, assume_linear_t) + QuadraticSpline(u, t, I, p, k, c, sc, extrapolation_down, + extrapolation_up, cache_parameters, assume_linear_t) end function QuadraticSpline( diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 9dc02b95..a7692434 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -10,28 +10,28 @@ end function _extrapolate_down(A, t) (; extrapolation_down) = A - if extrapolation_down == :none + if extrapolation_down == ExtrapolationType.none throw(ExtrapolationError(DOWN_EXTRAPOLATION_ERROR)) - elseif extrapolation_down == :constant + elseif extrapolation_down == ExtrapolationType.constant first(A.u) - elseif extrapolation_down == :linear + elseif extrapolation_down == ExtrapolationType.linear slope = derivative(A, first(A.t)) first(A.u) + slope * (t - first(A.t)) - elseif extrapolation_down == :extension + elseif extrapolation_down == ExtrapolationType.extension _interpolate(A, t, A.iguesser) end end function _extrapolate_up(A, t) (; extrapolation_up) = A - if extrapolation_up == :none + if extrapolation_up == ExtrapolationType.none throw(ExtrapolationError(UP_EXTRAPOLATION_ERROR)) - elseif extrapolation_up == :constant + elseif extrapolation_up == ExtrapolationType.constant last(A.u) - elseif extrapolation_up == :linear + elseif extrapolation_up == ExtrapolationType.linear slope = derivative(A, last(A.t)) last(A.u) + slope * (t - last(A.t)) - elseif extrapolation_up == :extension + elseif extrapolation_up == ExtrapolationType.extension _interpolate(A, t, A.iguesser) end end diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 77470423..2d1c6432 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -248,12 +248,6 @@ function get_parameters(A::QuinticHermiteSpline, idx) end end -function validate_extrapolation(method::Symbol) - if method ∉ extrapolation_types - error("Invalid extrapolation method `$method` supplied, use one of $extrapolation_types.") - end -end - function du_PCHIP(u, t) h = diff(u) δ = h ./ diff(t)