Skip to content

Commit

Permalink
Merge pull request SciML#349 from ashutosh-b-b/bb/bspline_array
Browse files Browse the repository at this point in the history
BSplineInterpolation: Add support for higher order arrays
  • Loading branch information
ChrisRackauckas authored Oct 21, 2024
2 parents c6c0090 + 26b109f commit 76c545b
Show file tree
Hide file tree
Showing 5 changed files with 357 additions and 5 deletions.
47 changes: 47 additions & 0 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,30 @@ function _derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Num
ducum * A.d * scale
end

function _derivative(
A::BSplineInterpolation{<:AbstractArray{<:Number, N}}, t::Number, iguess) where {N}
# change t into param [0 1]
ax_u = axes(A.u)[1:(end - 1)]
t < A.t[1] && return zeros(size(A.u)[1:(end - 1)]...)
t > A.t[end] && return zeros(size(A.u)[1:(end - 1)]...)
idx = get_idx(A, t, iguess)
n = length(A.t)
scale = (A.p[idx + 1] - A.p[idx]) / (A.t[idx + 1] - A.t[idx])
t_ = A.p[idx] + (t - A.t[idx]) * scale
sc = t isa ForwardDiff.Dual ? zeros(eltype(t), n) : A.sc
spline_coefficients!(sc, A.d - 1, A.k, t_)
ducum = zeros(size(A.u)[1:(end - 1)]...)
if t == A.t[1]
ducum = (A.c[ax_u..., 2] - A.c[ax_u..., 1]) / (A.k[A.d + 2])
else
for i in 1:(n - 1)
ducum = ducum +
sc[i + 1] * (A.c[ax_u..., i + 1] - A.c[ax_u..., i]) /
(A.k[i + A.d + 1] - A.k[i + 1])
end
end
ducum * A.d * scale
end
# BSpline Curve Approx
function _derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, iguess)
# change t into param [0 1]
Expand All @@ -185,6 +209,29 @@ function _derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, ig
ducum * A.d * scale
end

function _derivative(
A::BSplineApprox{<:AbstractArray{<:Number, N}}, t::Number, iguess) where {N}
# change t into param [0 1]
ax_u = axes(A.u)[1:(end - 1)]
t < A.t[1] && return zeros(size(A.u)[1:(end - 1)]...)
t > A.t[end] && return zeros(size(A.u)[1:(end - 1)]...)
idx = get_idx(A, t, iguess)
scale = (A.p[idx + 1] - A.p[idx]) / (A.t[idx + 1] - A.t[idx])
t_ = A.p[idx] + (t - A.t[idx]) * scale
sc = t isa ForwardDiff.Dual ? zeros(eltype(t), A.h) : A.sc
spline_coefficients!(sc, A.d - 1, A.k, t_)
ducum = zeros(size(A.u)[1:(end - 1)]...)
if t == A.t[1]
ducum = (A.c[ax_u..., 2] - A.c[ax_u..., 1]) / (A.k[A.d + 2])
else
for i in 1:(A.h - 1)
ducum = ducum +
sc[i + 1] * (A.c[ax_u..., i + 1] - A.c[ax_u..., i]) /
(A.k[i + A.d + 1] - A.k[i + 1])
end
end
ducum * A.d * scale
end
# Cubic Hermite Spline
function _derivative(
A::CubicHermiteSpline{<:AbstractVector{<:Number}}, t::Number, iguess)
Expand Down
172 changes: 170 additions & 2 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, scType, T, N} <:
end

function BSplineInterpolation(
u, t, d, pVecType, knotVecType; extrapolate = false, assume_linear_t = 1e-2)
u::AbstractVector, t, d, pVecType, knotVecType; extrapolate = false, assume_linear_t = 1e-2)
u, t = munge_data(u, t)
n = length(t)
n < d + 1 && error("BSplineInterpolation needs at least d + 1, i.e. $(d+1) points.")
Expand Down Expand Up @@ -665,6 +665,79 @@ function BSplineInterpolation(
u, t, d, p, k, c, sc, pVecType, knotVecType, extrapolate, assume_linear_t)
end

function BSplineInterpolation(
u::AbstractArray{T, N}, t, d, pVecType, knotVecType; extrapolate = false,
assume_linear_t = 1e-2) where {T, N}
u, t = munge_data(u, t)
n = length(t)
n < d + 1 && error("BSplineInterpolation needs at least d + 1, i.e. $(d+1) points.")
s = zero(eltype(u))
p = zero(t)
k = zeros(eltype(t), n + d + 1)
l = zeros(eltype(u), n - 1)
p[1] = zero(eltype(t))
p[end] = one(eltype(t))

ax_u = axes(u)[1:(end - 1)]

for i in 2:n
s += ((t[i] - t[i - 1])^2 + sum((u[ax_u..., i] - u[ax_u..., i - 1]) .^ 2))
l[i - 1] = s
end
if pVecType == :Uniform
for i in 2:(n - 1)
p[i] = p[1] + (i - 1) * (p[end] - p[1]) / (n - 1)
end
elseif pVecType == :ArcLen
for i in 2:(n - 1)
p[i] = p[1] + l[i - 1] / s * (p[end] - p[1])
end
end

lidx = 1
ridx = length(k)
while lidx <= (d + 1) && ridx >= (length(k) - d)
k[lidx] = p[1]
k[ridx] = p[end]
lidx += 1
ridx -= 1
end

ps = zeros(eltype(t), n - 2)
s = zero(eltype(t))
for i in 2:(n - 1)
s += p[i]
ps[i - 1] = s
end

if knotVecType == :Uniform
# uniformly spaced knot vector
# this method is not recommended because, if it is used with the chord length method for global interpolation,
# the system of linear equations would be singular.
for i in (d + 2):n
k[i] = k[1] + (i - d - 1) // (n - d) * (k[end] - k[1])
end
elseif knotVecType == :Average
# average spaced knot vector
idx = 1
if d + 2 <= n
k[d + 2] = 1 // d * ps[d]
end
for i in (d + 3):n
k[i] = 1 // d * (ps[idx + d] - ps[idx])
idx += 1
end
end
# control points
sc = zeros(eltype(t), n, n)
spline_coefficients!(sc, d, k, p)
c = (sc \ reshape(u, prod(size(u)[1:(end - 1)]), :)')'
c = reshape(c, size(u)...)
sc = zeros(eltype(t), n)
BSplineInterpolation(
u, t, d, p, k, c, sc, pVecType, knotVecType, extrapolate, assume_linear_t)
end

"""
BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = false)
Expand Down Expand Up @@ -738,7 +811,7 @@ struct BSplineApprox{uType, tType, pType, kType, cType, scType, T, N} <:
end

function BSplineApprox(
u, t, d, h, pVecType, knotVecType; extrapolate = false, assume_linear_t = 1e-2)
u::AbstractVector, t, d, h, pVecType, knotVecType; extrapolate = false, assume_linear_t = 1e-2)
u, t = munge_data(u, t)
n = length(t)
h < d + 1 && error("BSplineApprox needs at least d + 1, i.e. $(d+1) control points.")
Expand Down Expand Up @@ -827,6 +900,101 @@ function BSplineApprox(
u, t, d, h, p, k, c, sc, pVecType, knotVecType, extrapolate, assume_linear_t)
end

function BSplineApprox(
u::AbstractArray{T, N}, t, d, h, pVecType, knotVecType; extrapolate = false,
assume_linear_t = 1e-2) where {T, N}
u, t = munge_data(u, t)
n = length(t)
h < d + 1 && error("BSplineApprox needs at least d + 1, i.e. $(d+1) control points.")
s = zero(eltype(u))
p = zero(t)
k = zeros(eltype(t), h + d + 1)
l = zeros(eltype(u), n - 1)
p[1] = zero(eltype(t))
p[end] = one(eltype(t))

ax_u = axes(u)[1:(end - 1)]

for i in 2:n
s += ((t[i] - t[i - 1])^2 + sum((u[ax_u..., i] - u[ax_u..., i - 1]) .^ 2))
l[i - 1] = s
end
if pVecType == :Uniform
for i in 2:(n - 1)
p[i] = p[1] + (i - 1) * (p[end] - p[1]) / (n - 1)
end
elseif pVecType == :ArcLen
for i in 2:(n - 1)
p[i] = p[1] + l[i - 1] / s * (p[end] - p[1])
end
end

lidx = 1
ridx = length(k)
while lidx <= (d + 1) && ridx >= (length(k) - d)
k[lidx] = p[1]
k[ridx] = p[end]
lidx += 1
ridx -= 1
end

ps = zeros(eltype(t), n - 2)
s = zero(eltype(t))
for i in 2:(n - 1)
s += p[i]
ps[i - 1] = s
end

if knotVecType == :Uniform
# uniformly spaced knot vector
# this method is not recommended because, if it is used with the chord length method for global interpolation,
# the system of linear equations would be singular.
for i in (d + 2):h
k[i] = k[1] + (i - d - 1) // (h - d) * (k[end] - k[1])
end
elseif knotVecType == :Average
# NOTE: verify that average method can be applied when size of k is less than size of p
# average spaced knot vector
idx = 1
if d + 2 <= h
k[d + 2] = 1 // d * ps[d]
end
for i in (d + 3):h
k[i] = 1 // d * (ps[idx + d] - ps[idx])
idx += 1
end
end
# control points
c = zeros(eltype(u), size(u)[1:(end - 1)]..., h)
c[ax_u..., 1] = u[ax_u..., 1]
c[ax_u..., end] = u[ax_u..., end]
q = zeros(eltype(u), size(u)[1:(end - 1)]..., n)
sc = zeros(eltype(t), n, h)
for i in 1:n
spline_coefficients!(view(sc, i, :), d, k, p[i])
end
for k in 2:(n - 1)
q[ax_u..., k] = u[ax_u..., k] - sc[k, 1] * u[ax_u..., 1] -
sc[k, h] * u[ax_u..., end]
end
Q = Array{eltype(u), N}(undef, size(u)[1:(end - 1)]..., h - 2)
for i in 2:(h - 1)
s = zeros(eltype(sc), size(u)[1:(end - 1)]...)
for k in 2:(n - 1)
s = s + sc[k, i] * q[ax_u..., k]
end
Q[ax_u..., i - 1] = s
end
sc = sc[2:(end - 1), 2:(h - 1)]
M = transpose(sc) * sc
Q = reshape(Q, prod(size(u)[1:(end - 1)]), :)
P = (M \ Q')'
P = reshape(P, size(u)[1:(end - 1)]..., :)
c[ax_u..., 2:(end - 1)] = P
sc = zeros(eltype(t), h)
BSplineApprox(
u, t, d, h, p, k, c, sc, pVecType, knotVecType, extrapolate, assume_linear_t)
end
"""
CubicHermiteSpline(du, u, t; extrapolate = false, cache_parameters = false)
Expand Down
36 changes: 36 additions & 0 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,25 @@ function _interpolate(A::BSplineInterpolation{<:AbstractVector{<:Number}},
ucum
end

function _interpolate(A::BSplineInterpolation{<:AbstractArray{T, N}},
t::Number,
iguess) where {T <: Number, N}
ax_u = axes(A.u)[1:(end - 1)]
t < A.t[1] && return A.u[ax_u..., 1]
t > A.t[end] && return A.u[ax_u..., end]
# change t into param [0 1]
idx = get_idx(A, t, iguess)
t = A.p[idx] + (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) * (A.p[idx + 1] - A.p[idx])
n = length(A.t)
sc = t isa ForwardDiff.Dual ? zeros(eltype(t), n) : A.sc
nonzero_coefficient_idxs = spline_coefficients!(sc, A.d, A.k, t)
ucum = zeros(eltype(A.u), size(A.u)[1:(end - 1)]...)
for i in nonzero_coefficient_idxs
ucum = ucum + (sc[i] * A.c[ax_u..., i])
end
ucum
end

# BSpline Curve Approx
function _interpolate(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, iguess)
t < A.t[1] && return A.u[1]
Expand All @@ -213,6 +232,23 @@ function _interpolate(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, i
ucum
end

function _interpolate(
A::BSplineApprox{<:AbstractArray{T, N}}, t::Number, iguess) where {T <: Number, N}
ax_u = axes(A.u)[1:(end - 1)]
t < A.t[1] && return A.u[ax_u..., 1]
t > A.t[end] && return A.u[ax_u..., end]
# change t into param [0 1]
idx = get_idx(A, t, iguess)
t = A.p[idx] + (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) * (A.p[idx + 1] - A.p[idx])
sc = t isa ForwardDiff.Dual ? zeros(eltype(t), A.h) : A.sc
nonzero_coefficient_idxs = spline_coefficients!(sc, A.d, A.k, t)
ucum = zeros(eltype(A.u), size(A.u)[1:(end - 1)]...)
for i in nonzero_coefficient_idxs
ucum = ucum + (sc[i] * A.c[ax_u..., i])
end
ucum
end

# Cubic Hermite Spline
function _interpolate(
A::CubicHermiteSpline{<:AbstractVector{<:Number}}, t::Number, iguess)
Expand Down
38 changes: 38 additions & 0 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,44 @@ end
:Uniform,
:Uniform],
name = "BSpline Approx (Uniform, Uniform)")

f3d(t) = [sin(t) cos(t);
0.0 cos(2t)]

t3d = 0.1:0.1:1.0 |> collect
u3d = cat(f3d.(t3d)...; dims = 3)
test_derivatives(BSplineInterpolation;
args = [u3d, t3d,
2,
:Uniform,
:Uniform],
name = "BSpline Interpolation (Uniform, Uniform): AbstractArray"
)

test_derivatives(BSplineInterpolation;
args = [u3d, t3d,
2,
:ArcLen,
:Average],
name = "BSpline Interpolation (Arclen, Average): AbstractArray"
)

test_derivatives(BSplineApprox;
args = [u3d, t3d,
3,
4,
:Uniform,
:Uniform],
name = "BSpline Approx (Uniform, Uniform): AbstractArray")

test_derivatives(BSplineApprox;
args = [u3d, t3d,
3,
4,
:ArcLen,
:Average],
name = "BSpline Approx (Arclen, Average): AbstractArray"
)
end

@testset "Cubic Hermite Spline" begin
Expand Down
Loading

0 comments on commit 76c545b

Please sign in to comment.