Skip to content

Commit

Permalink
Merge pull request SciML#193 from sathvikbhagavan/sb/add_extrapolate_…
Browse files Browse the repository at this point in the history
…keyword

Add extrapolation keyword
  • Loading branch information
ChrisRackauckas authored Oct 17, 2023
2 parents 2649c4b + 1802631 commit 5ad4c24
Show file tree
Hide file tree
Showing 9 changed files with 375 additions and 248 deletions.
6 changes: 6 additions & 0 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ function (interp::AbstractInterpolation)(u::AbstractVector, t::AbstractVector)
u
end

const EXTRAPOLATION_ERROR = "Cannot extrapolate as `extrapolate` keyword passed was `false`"
struct ExtrapolationError <: Exception end
function Base.showerror(io::IO, e::ExtrapolationError)
print(io, EXTRAPOLATION_ERROR)
end

export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation,
AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline,
BSplineInterpolation, BSplineApprox
Expand Down
9 changes: 8 additions & 1 deletion src/derivatives.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
derivative(A, t) = derivative(A, t, firstindex(A.t) - 1)[1]
function derivative(A, t)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
derivative(A, t, firstindex(A.t) - 1)[1]
end

function derivative(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess)
idx = searchsortedfirstcorrelated(A.t, t, iguess)
Expand Down Expand Up @@ -33,6 +36,7 @@ function derivative(A::QuadraticInterpolation{<:AbstractMatrix}, t::Number, igue
end

function derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return zero(A.u[idxs[1]])
Expand Down Expand Up @@ -68,6 +72,7 @@ function derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
end

function derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return zero(A.u[:, idxs[1]])
Expand Down Expand Up @@ -115,10 +120,12 @@ function derivative(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
end

function derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
return isempty(searchsorted(A.t, t)) ? zero(A.u[1]) : eltype(A.u)(NaN)
end

function derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
return isempty(searchsorted(A.t, t)) ? zero(A.u[:, 1]) : eltype(A.u)(NaN) .* A.u[:, 1]
end

Expand Down
126 changes: 85 additions & 41 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,38 @@
struct LinearInterpolation{uType, tType, FT, T} <: AbstractInterpolation{FT, T}
u::uType
t::tType
function LinearInterpolation{FT}(u, t) where {FT}
new{typeof(u), typeof(t), FT, eltype(u)}(u, t)
extrapolate::Bool
function LinearInterpolation{FT}(u, t, extrapolate) where {FT}
new{typeof(u), typeof(t), FT, eltype(u)}(u, t, extrapolate)
end
end

function LinearInterpolation(u, t)
function LinearInterpolation(u, t; extrapolate = true)
u, t = munge_data(u, t)
LinearInterpolation{true}(u, t)
LinearInterpolation{true}(u, t, extrapolate)
end

### Quadratic Interpolation
struct QuadraticInterpolation{uType, tType, FT, T} <: AbstractInterpolation{FT, T}
u::uType
t::tType
mode::Symbol
function QuadraticInterpolation{FT}(u, t, mode) where {FT}
extrapolate::Bool
function QuadraticInterpolation{FT}(u, t, mode, extrapolate) where {FT}
mode (:Forward, :Backward) ||
error("mode should be :Forward or :Backward for QuadraticInterpolation")
new{typeof(u), typeof(t), FT, eltype(u)}(u, t, mode)
new{typeof(u), typeof(t), FT, eltype(u)}(u, t, mode, extrapolate)
end
end

function QuadraticInterpolation(u, t, mode)
function QuadraticInterpolation(u, t, mode; extrapolate = true)
u, t = munge_data(u, t)
QuadraticInterpolation{true}(u, t, mode)
QuadraticInterpolation{true}(u, t, mode, extrapolate)
end

QuadraticInterpolation(u, t) = QuadraticInterpolation(u, t, :Forward)
function QuadraticInterpolation(u, t; extrapolate = true)
QuadraticInterpolation(u, t, :Forward; extrapolate)
end

### Lagrange Interpolation
struct LagrangeInterpolation{uType, tType, FT, T, bcacheType} <:
Expand All @@ -38,22 +42,27 @@ struct LagrangeInterpolation{uType, tType, FT, T, bcacheType} <:
t::tType
n::Int
bcache::bcacheType
function LagrangeInterpolation{FT}(u, t, n) where {FT}
extrapolate::Bool
function LagrangeInterpolation{FT}(u, t, n, extrapolate) where {FT}
bcache = zeros(eltype(u[1]), n + 1)
fill!(bcache, NaN)
new{typeof(u), typeof(t), FT, eltype(u), typeof(bcache)}(u, t, n, bcache)
new{typeof(u), typeof(t), FT, eltype(u), typeof(bcache)}(u,
t,
n,
bcache,
extrapolate)
end
end

function LagrangeInterpolation(u, t, n = nothing)
function LagrangeInterpolation(u, t, n = nothing; extrapolate = true)
u, t = munge_data(u, t)
if isnothing(n)
n = length(t) - 1 # degree
end
if n != length(t) - 1
error("Currently only n=length(t) - 1 is supported")
end
LagrangeInterpolation{true}(u, t, n)
LagrangeInterpolation{true}(u, t, n, extrapolate)
end

### Akima Interpolation
Expand All @@ -64,17 +73,19 @@ struct AkimaInterpolation{uType, tType, bType, cType, dType, FT, T} <:
b::bType
c::cType
d::dType
function AkimaInterpolation{FT}(u, t, b, c, d) where {FT}
extrapolate::Bool
function AkimaInterpolation{FT}(u, t, b, c, d, extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(b), typeof(c),
typeof(d), FT, eltype(u)}(u,
t,
b,
c,
d)
d,
extrapolate)
end
end

function AkimaInterpolation(u, t)
function AkimaInterpolation(u, t; extrapolate = true)
u, t = munge_data(u, t)
n = length(t)
dt = diff(t)
Expand All @@ -96,22 +107,23 @@ function AkimaInterpolation(u, t)
c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt
d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2

AkimaInterpolation{true}(u, t, b, c, d)
AkimaInterpolation{true}(u, t, b, c, d, extrapolate)
end

### ConstantInterpolation Interpolation
struct ConstantInterpolation{uType, tType, dirType, FT, T} <: AbstractInterpolation{FT, T}
u::uType
t::tType
dir::Symbol # indicates if value to the $dir should be used for the interpolation
function ConstantInterpolation{FT}(u, t, dir) where {FT}
new{typeof(u), typeof(t), typeof(dir), FT, eltype(u)}(u, t, dir)
extrapolate::Bool
function ConstantInterpolation{FT}(u, t, dir, extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(dir), FT, eltype(u)}(u, t, dir, extrapolate)
end
end

function ConstantInterpolation(u, t; dir = :left)
function ConstantInterpolation(u, t; dir = :left, extrapolate = true)
u, t = munge_data(u, t)
ConstantInterpolation{true}(u, t, dir)
ConstantInterpolation{true}(u, t, dir, extrapolate)
end

Base.@deprecate_binding ZeroSpline ConstantInterpolation
Expand All @@ -124,17 +136,21 @@ struct QuadraticSpline{uType, tType, tAType, dType, zType, FT, T} <:
tA::tAType
d::dType
z::zType
function QuadraticSpline{FT}(u, t, tA, d, z) where {FT}
extrapolate::Bool
function QuadraticSpline{FT}(u, t, tA, d, z, extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(tA),
typeof(d), typeof(z), FT, eltype(u)}(u,
t,
tA,
d,
z)
z,
extrapolate)
end
end

function QuadraticSpline(u::uType, t) where {uType <: AbstractVector{<:Number}}
function QuadraticSpline(u::uType,
t;
extrapolate = true) where {uType <: AbstractVector{<:Number}}
u, t = munge_data(u, t)
s = length(t)
dl = ones(eltype(t), s - 1)
Expand All @@ -147,10 +163,10 @@ function QuadraticSpline(u::uType, t) where {uType <: AbstractVector{<:Number}}

d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
z = tA \ d
QuadraticSpline{true}(u, t, tA, d, z)
QuadraticSpline{true}(u, t, tA, d, z, extrapolate)
end

function QuadraticSpline(u::uType, t) where {uType <: AbstractVector}
function QuadraticSpline(u::uType, t; extrapolate = true) where {uType <: AbstractVector}
u, t = munge_data(u, t)
s = length(t)
dl = ones(eltype(t), s - 1)
Expand All @@ -163,7 +179,7 @@ function QuadraticSpline(u::uType, t) where {uType <: AbstractVector}
d = transpose(reshape(reduce(hcat, d_), :, s))
z_ = reshape(transpose(tA \ d), size(u[1])..., :)
z = [z_s for z_s in eachslice(z_, dims = ndims(z_))]
QuadraticSpline{true}(u, t, tA, d, z)
QuadraticSpline{true}(u, t, tA, d, z, extrapolate)
end

# Cubic Spline Interpolation
Expand All @@ -172,12 +188,19 @@ struct CubicSpline{uType, tType, hType, zType, FT, T} <: AbstractInterpolation{F
t::tType
h::hType
z::zType
function CubicSpline{FT}(u, t, h, z) where {FT}
new{typeof(u), typeof(t), typeof(h), typeof(z), FT, eltype(u)}(u, t, h, z)
extrapolate::Bool
function CubicSpline{FT}(u, t, h, z, extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(h), typeof(z), FT, eltype(u)}(u,
t,
h,
z,
extrapolate)
end
end

function CubicSpline(u::uType, t) where {uType <: AbstractVector{<:Number}}
function CubicSpline(u::uType,
t;
extrapolate = true) where {uType <: AbstractVector{<:Number}}
u, t = munge_data(u, t)
n = length(t) - 1
h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0)
Expand All @@ -194,10 +217,10 @@ function CubicSpline(u::uType, t) where {uType <: AbstractVector{<:Number}}
6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i],
1:(n + 1))
z = tA \ d
CubicSpline{true}(u, t, h[1:(n + 1)], z)
CubicSpline{true}(u, t, h[1:(n + 1)], z, extrapolate)
end

function CubicSpline(u::uType, t) where {uType <: AbstractVector}
function CubicSpline(u::uType, t; extrapolate = true) where {uType <: AbstractVector}
u, t = munge_data(u, t)
n = length(t) - 1
h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0)
Expand All @@ -211,7 +234,7 @@ function CubicSpline(u::uType, t) where {uType <: AbstractVector}
d = transpose(reshape(reduce(hcat, d_), :, n + 1))
z_ = reshape(transpose(tA \ d), size(u[1])..., :)
z = [z_s for z_s in eachslice(z_, dims = ndims(z_))]
CubicSpline{true}(u, t, h[1:(n + 1)], z)
CubicSpline{true}(u, t, h[1:(n + 1)], z, extrapolate)
end

### BSpline Curve Interpolation
Expand All @@ -225,19 +248,29 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, FT, T} <:
c::cType # control points
pVecType::Symbol
knotVecType::Symbol
function BSplineInterpolation{FT}(u, t, d, p, k, c, pVecType, knotVecType) where {FT}
extrapolate::Bool
function BSplineInterpolation{FT}(u,
t,
d,
p,
k,
c,
pVecType,
knotVecType,
extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), FT, eltype(u)}(u,
t,
d,
p,
k,
c,
pVecType,
knotVecType)
knotVecType,
extrapolate)
end
end

function BSplineInterpolation(u, t, d, pVecType, knotVecType)
function BSplineInterpolation(u, t, d, pVecType, knotVecType; extrapolate = true)
u, t = munge_data(u, t)
n = length(t)
s = zero(eltype(u))
Expand Down Expand Up @@ -298,7 +331,7 @@ function BSplineInterpolation(u, t, d, pVecType, knotVecType)
# control points
N = spline_coefficients(n, d, k, p)
c = vec(N \ u[:, :])
BSplineInterpolation{true}(u, t, d, p, k, c, pVecType, knotVecType)
BSplineInterpolation{true}(u, t, d, p, k, c, pVecType, knotVecType, extrapolate)
end

### BSpline Curve Approx
Expand All @@ -313,7 +346,17 @@ struct BSplineApprox{uType, tType, pType, kType, cType, FT, T} <:
c::cType # control points
pVecType::Symbol
knotVecType::Symbol
function BSplineApprox{FT}(u, t, d, h, p, k, c, pVecType, knotVecType) where {FT}
extrapolate::Bool
function BSplineApprox{FT}(u,
t,
d,
h,
p,
k,
c,
pVecType,
knotVecType,
extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), FT, eltype(u)}(u,
t,
d,
Expand All @@ -322,11 +365,12 @@ struct BSplineApprox{uType, tType, pType, kType, cType, FT, T} <:
k,
c,
pVecType,
knotVecType)
knotVecType,
extrapolate)
end
end

function BSplineApprox(u, t, d, h, pVecType, knotVecType)
function BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = true)
u, t = munge_data(u, t)
n = length(t)
s = zero(eltype(u))
Expand Down Expand Up @@ -409,5 +453,5 @@ function BSplineApprox(u, t, d, h, pVecType, knotVecType)
M = transpose(N) * N
P = M \ Q
c[2:(end - 1)] .= vec(P)
BSplineApprox{true}(u, t, d, h, p, k, c, pVecType, knotVecType)
BSplineApprox{true}(u, t, d, h, p, k, c, pVecType, knotVecType, extrapolate)
end
8 changes: 7 additions & 1 deletion src/interpolation_methods.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
_interpolate(interp, t) = _interpolate(interp, t, firstindex(interp.t) - 1)[1]
function _interpolate(interp, t)
((t < interp.t[1] || t > interp.t[end]) && !interp.extrapolate) &&
throw(ExtrapolationError())
_interpolate(interp, t, firstindex(interp.t) - 1)[1]
end

# Linear Interpolation
function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess)
Expand Down Expand Up @@ -53,6 +57,7 @@ end

# Lagrange Interpolation
function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return A.u[idxs[1]]
Expand Down Expand Up @@ -81,6 +86,7 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
end

function _interpolate(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return A.u[:, idxs[1]]
Expand Down
2 changes: 1 addition & 1 deletion src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ function searchsortedlastcorrelated(v::AbstractVector, x, guess)
end

searchsortedfirstcorrelated(r::AbstractRange, x, _) = searchsortedfirst(r, x)
searchsortedlastcorrelated(r::AbstractRange, x, _) = searchsortedlast(r, x)
searchsortedlastcorrelated(r::AbstractRange, x, _) = searchsortedlast(r, x)
Loading

0 comments on commit 5ad4c24

Please sign in to comment.