Skip to content

Commit

Permalink
refactor: methods for higher dim arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Oct 17, 2024
1 parent c6c0090 commit caf0b76
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 13 deletions.
27 changes: 27 additions & 0 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,33 @@ function QuadraticSpline(
QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
end

function QuadraticSpline(
u::uType, t; extrapolate = false, cache_parameters = false,
assume_linear_t = 1e-2) where {uType <:
AbstractArray{T, N}} where {T, N}
u, t = munge_data(u, t)
linear_lookup = seems_linear(assume_linear_t, t)
s = length(t)
dl = ones(eltype(t), s - 1)
d_tmp = ones(eltype(t), s)
du = zeros(eltype(t), s - 1)
tA = Tridiagonal(dl, d_tmp, du)
ax = axes(u)[1:(end - 1)]
d_ = map(
i -> i == 1 ? zeros(eltype(t), size(u[ax..., 1])) :
2 // 1 * (u[ax..., i] - u[ax..., i - 1]) / (t[i] - t[i - 1]),
1:s)
d = transpose(reshape(reduce(hcat, d_), :, s))
z_ = reshape(transpose(tA \ d), size(u[ax..., 1])..., :)
z = [z_s for z_s in eachslice(z_, dims = ndims(z_))]

p = QuadraticSplineParameterCache(z, t, cache_parameters)
A = QuadraticSpline(
u, t, nothing, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
I = cumulative_integral(A, cache_parameters)
QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, cache_parameters, linear_lookup)
end

"""
CubicSpline(u, t; extrapolate = false, cache_parameters = false)
Expand Down
27 changes: 20 additions & 7 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,15 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number, igu
N / D
end

function _interpolate(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, iguess)
function _interpolate(
A::LagrangeInterpolation{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
idx = get_idx(A, t, iguess)
findRequiredIdxs!(A, t, idx)
ax = axes(A.u)[1:(end - 1)]
if A.t[A.idxs[1]] == t
return A.u[:, A.idxs[1]]
return A.u[ax..., A.idxs[1]]
end
N = zero(A.u[:, 1])
N1 = zero(A.u[ax..., 1])
D = zero(A.t[1])
tmp = D
for i in 1:length(A.idxs)
Expand All @@ -111,9 +113,9 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, igu
end
tmp = inv((t - A.t[A.idxs[i]]) * mult)
D += tmp
@. N += (tmp * A.u[:, A.idxs[i]])
@. N1 += (tmp * A.u[ax..., A.idxs[i]])
end
N / D
N1 / D
end

function _interpolate(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
Expand All @@ -134,15 +136,16 @@ function _interpolate(A::ConstantInterpolation{<:AbstractVector}, t::Number, igu
A.u[idx]
end

function _interpolate(A::ConstantInterpolation{<:AbstractMatrix}, t::Number, iguess)
function _interpolate(
A::ConstantInterpolation{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
if A.dir === :left
# :left means that value to the left is used for interpolation
idx = get_idx(A, t, iguess; lb = 1, ub_shift = 0)
else
# :right means that value to the right is used for interpolation
idx = get_idx(A, t, iguess; side = :first, lb = 1, ub_shift = 0)
end
A.u[:, idx]
A.u[axes(A.u)[1:(end - 1)]..., idx]
end

# QuadraticSpline Interpolation
Expand All @@ -154,6 +157,16 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
return A.z[idx] * Δt + σ * Δt^2 + Cᵢ
end

function _interpolate(
A::QuadraticSpline{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
idx = get_idx(A, t, iguess)
ax = axes(A.u)[1:(end - 1)]
Cᵢ = A.u[ax..., idx]
Δt = t - A.t[idx]
σ = get_parameters(A, idx)
return A.z[idx] * Δt + σ * Δt^2 + Cᵢ
end

# CubicSpline Interpolation
function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess)
Expand Down
19 changes: 13 additions & 6 deletions src/parameter_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ function QuadraticParameterCache(u, t, cache_parameters)
end
end

function quadratic_interpolation_parameters(u, t, idx)
if u isa AbstractMatrix
u₀ = u[:, idx]
u₁ = u[:, idx + 1]
u₂ = u[:, idx + 2]
function quadratic_interpolation_parameters(u::AbstractArray{T, N}, t, idx) where {T, N}
if N > 1
ax = axes(u)
u₀ = u[ax[1:(end - 1)]..., idx]
u₁ = u[ax[1:(end - 1)]..., idx + 1]
u₂ = u[ax[1:(end - 1)]..., idx + 2]
else
u₀ = u[idx]
u₁ = u[idx + 1]
Expand Down Expand Up @@ -89,11 +90,17 @@ function QuadraticSplineParameterCache(z, t, cache_parameters)
end
end

function quadratic_spline_parameters(z, t, idx)
function quadratic_spline_parameters(z::AbstractVector, t, idx)
σ = 1 // 2 * (z[idx + 1] - z[idx]) / (t[idx + 1] - t[idx])
return σ
end

function quadratic_spline_parameters(z::AbstractArray, t, idx)
ax = axes(z)[1:(end - 1)]
σ = 1 // 2 * (z[ax..., idx + 1] - z[ax..., idx]) / (t[idx + 1] - t[idx])
return σ
end

struct CubicSplineParameterCache{pType}
c₁::pType
c₂::pType
Expand Down

0 comments on commit caf0b76

Please sign in to comment.