Skip to content

Commit

Permalink
feat(CubicSpline): add dispatch for AbstractMatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
ashutosh-b-b committed Oct 7, 2024
1 parent e9a6d73 commit e3406e5
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,37 @@ function CubicSpline(u::uType,
CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters, linear_lookup)
end

function CubicSpline(u::uType,
t;
extrapolate = false, cache_parameters = false,
assume_linear_t = 1e-2) where {uType <:
AbstractMatrix}
u, t = munge_data(u, t)
n = length(t) - 1
h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0)
dl = vcat(h[2:n], zero(eltype(h)))
d_tmp = 2 .* (h[1:(n + 1)] .+ h[2:(n + 2)])
du = vcat(zero(eltype(h)), h[3:(n + 1)])
tA = Tridiagonal(dl, d_tmp, du)

# zero for element type of d, which we don't know yet
ax = axes(u)[1:end-1]
typed_zero = zero(6(u[ax..., begin + 2] - u[ax..., begin + 1]) / h[begin + 2] -
6(u[ax..., begin + 1] - u[ax..., begin]) / h[begin + 1])

h_ = reshape(h, 1, :)
d = 6*((u[ax..., 3:n+1] - u[ax..., 2:n]) ./ h_[:, 3:n+1]) - 6*((u[ax..., 2:n] - u[ax..., 1:n-1]) ./ h_[:, 2:n])
d = cat(typed_zero, d, typed_zero; dims = ndims(d))

z = (tA \ d')'
linear_lookup = seems_linear(assume_linear_t, t)
p = CubicSplineParameterCache(u, h, z, cache_parameters)
A = CubicSpline(
u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters, linear_lookup)
I = cumulative_integral(A, cache_parameters)
CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters, linear_lookup)
end

function CubicSpline(
u::uType, t; extrapolate = false, cache_parameters = false,
assume_linear_t = 1e-2) where {uType <:
Expand All @@ -486,6 +517,8 @@ function CubicSpline(
CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters, assume_linear_t)
end



"""
BSplineInterpolation(u, t, d, pVecType, knotVecType; extrapolate = false, safetycopy = true)
Expand Down

0 comments on commit e3406e5

Please sign in to comment.