Skip to content

Commit

Permalink
Use enumx for extrapolation types
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Nov 14, 2024
1 parent 76859cd commit 832ca31
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 26 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
version = "6.6.0"

[deps]
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -27,6 +28,7 @@ DataInterpolationsSymbolicsExt = "Symbolics"
Aqua = "0.8"
BenchmarkTools = "1"
ChainRulesCore = "1.24"
EnumX = "1.0.4"
FindFirstFunctions = "1.3"
FiniteDifferences = "0.12.31"
ForwardDiff = "0.10.36"
Expand Down
12 changes: 7 additions & 5 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ abstract type AbstractInterpolation{T, N} end
using LinearAlgebra, RecipesBase
using PrettyTables
using ForwardDiff
using EnumX
import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated,
Guesser

@enumx ExtrapolationType none constant linear extension

include("parameter_caches.jl")
include("interpolation_caches.jl")
include("interpolation_utils.jl")
Expand Down Expand Up @@ -55,13 +58,13 @@ function (interp::AbstractInterpolation)(u::AbstractVector, t::AbstractVector)
u
end

const DOWN_EXTRAPOLATION_ERROR = "Cannot extrapolate down as `extrapolation_down` keyword passed was `:none`"
const DOWN_EXTRAPOLATION_ERROR = "Cannot extrapolate down as `extrapolation_down` keyword passed was `none`"
struct DownExtrapolationError <: Exception end
function Base.showerror(io::IO, ::DownExtrapolationError)
print(io, DOWN_EXTRAPOLATION_ERROR)
end

const UP_EXTRAPOLATION_ERROR = "Cannot extrapolate up as `extrapolation_up` keyword passed was `:none`"
const UP_EXTRAPOLATION_ERROR = "Cannot extrapolate up as `extrapolation_up` keyword passed was `none`"
struct UpExtrapolationError <: Exception end
function Base.showerror(io::IO, ::UpExtrapolationError)
print(io, UP_EXTRAPOLATION_ERROR)
Expand Down Expand Up @@ -94,9 +97,8 @@ end
export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation,
AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline,
BSplineInterpolation, BSplineApprox, CubicHermiteSpline, PCHIPInterpolation,
QuinticHermiteSpline, LinearInterpolationIntInv, ConstantInterpolationIntInv

const extrapolation_types::Vector{Symbol} = [:none, :constant, :linear, :extension]
QuinticHermiteSpline, LinearInterpolationIntInv, ConstantInterpolationIntInv,
ExtrapolationType

# added for RegularizationSmooth, JJS 11/27/21
### Regularization data smoothing and interpolation
Expand Down
15 changes: 8 additions & 7 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} <
k::kType # knot vector
c::cType # B-spline control points
sc::scType # Spline coefficients (preallocated memory)
extrapolation_down::Symbol
extrapolation_up::Symbol
extrapolation_down::ExtrapolationType.T
extrapolation_up::ExtrapolationType.T
iguesser::Guesser{tType}
cache_parameters::Bool
linear_lookup::Bool
Expand All @@ -324,8 +324,6 @@ struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} <
extrapolation_up, cache_parameters, assume_linear_t)
linear_lookup = seems_linear(assume_linear_t, t)
N = get_output_dim(u)
validate_extrapolation(extrapolation_down)
validate_extrapolation(extrapolation_up)
new{typeof(u), typeof(t), typeof(I), typeof(p.α), typeof(k),
typeof(c), typeof(sc), eltype(u), N}(u,
t,
Expand All @@ -344,7 +342,8 @@ struct QuadraticSpline{uType, tType, IType, pType, kType, cType, scType, T, N} <
end

function QuadraticSpline(
u::uType, t; extrapolation_down::Symbol = :none, extrapolation_up::Symbol = :none,
u::uType, t; extrapolation_down::ExtrapolationType.T = ExtrapolationType.none,
extrapolation_up::ExtrapolationType.T = ExtrapolationType.none,
cache_parameters = false, assume_linear_t = 1e-2) where {uType <:
AbstractVector{<:Number}}
u, t = munge_data(u, t)
Expand All @@ -357,9 +356,11 @@ function QuadraticSpline(

p = QuadraticSplineParameterCache(u, t, k, c, sc, cache_parameters)
A = QuadraticSpline(
u, t, nothing, p, k, c, sc, extrapolation_down, extrapolation_up, cache_parameters, assume_linear_t)
u, t, nothing, p, k, c, sc, extrapolation_down,
extrapolation_up, cache_parameters, assume_linear_t)
I = cumulative_integral(A, cache_parameters)
QuadraticSpline(u, t, I, p, k, c, sc, extrapolation_down, extrapolation_up, cache_parameters, assume_linear_t)
QuadraticSpline(u, t, I, p, k, c, sc, extrapolation_down,
extrapolation_up, cache_parameters, assume_linear_t)
end

function QuadraticSpline(
Expand Down
16 changes: 8 additions & 8 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@ end

function _extrapolate_down(A, t)
(; extrapolation_down) = A
if extrapolation_down == :none
if extrapolation_down == ExtrapolationType.none
throw(ExtrapolationError(DOWN_EXTRAPOLATION_ERROR))
elseif extrapolation_down == :constant
elseif extrapolation_down == ExtrapolationType.constant
first(A.u)
elseif extrapolation_down == :linear
elseif extrapolation_down == ExtrapolationType.linear
slope = derivative(A, first(A.t))
first(A.u) + slope * (t - first(A.t))
elseif extrapolation_down == :extension
elseif extrapolation_down == ExtrapolationType.extension
_interpolate(A, t, A.iguesser)
end
end

function _extrapolate_up(A, t)
(; extrapolation_up) = A
if extrapolation_up == :none
if extrapolation_up == ExtrapolationType.none
throw(ExtrapolationError(UP_EXTRAPOLATION_ERROR))
elseif extrapolation_up == :constant
elseif extrapolation_up == ExtrapolationType.constant
last(A.u)
elseif extrapolation_up == :linear
elseif extrapolation_up == ExtrapolationType.linear
slope = derivative(A, last(A.t))
last(A.u) + slope * (t - last(A.t))
elseif extrapolation_up == :extension
elseif extrapolation_up == ExtrapolationType.extension
_interpolate(A, t, A.iguesser)
end
end
Expand Down
6 changes: 0 additions & 6 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,6 @@ function get_parameters(A::QuinticHermiteSpline, idx)
end
end

function validate_extrapolation(method::Symbol)
if method extrapolation_types
error("Invalid extrapolation method `$method` supplied, use one of $extrapolation_types.")
end
end

function du_PCHIP(u, t)
h = diff(u)
δ = h ./ diff(t)
Expand Down

0 comments on commit 832ca31

Please sign in to comment.