Skip to content

Commit 64a5423

Browse files
committed
Allow smoothing splines with vector data
1 parent 3754f65 commit 64a5423

File tree

2 files changed

+74
-14
lines changed

2 files changed

+74
-14
lines changed

src/SplineInterpolations/smoothing.jl

+31-12
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ function fit(
5050
λ 0 || throw(DomainError(λ, "the smoothing parameter λ must be non-negative"))
5151
eachindex(xs) == eachindex(ys) || throw(DimensionMismatch("x and y vectors must have the same length"))
5252
N = length(xs)
53-
cs = similar(xs)
53+
cs = similar(ys)
5454

55-
T = eltype(cs)
55+
T = eltype(xs)
5656

5757
# Create natural cubic B-spline basis with knots = input points
5858
B = BSplineBasis(order, copy(xs))
@@ -98,15 +98,24 @@ function fit(
9898
# ldiv!(cs, F, z)
9999

100100
# Construct RHS trying to reduce allocations
101-
zs = copy(ys)
101+
Y = eltype(ys)
102+
if ndims(Y) == 0 # scalar data -- this includes float types (ndims(Float64) == 0)
103+
cs_lin = cs
104+
zs = copy(ys)
105+
else # vector or multidimensional data, for example for SVector values
106+
Z = eltype(eltype(ys)) # for example, if eltype(ys) <: SVector{D,Z}
107+
@assert Z <: Number
108+
cs_lin = reinterpret(reshape, Z, cs)'
109+
zs = copy(reinterpret(reshape, Z, ys)') # dimensions (N, D)
110+
end
102111
if weights !== nothing
103112
eachindex(weights) == eachindex(xs) || throw(DimensionMismatch("the `weights` vector must have the same length as the data"))
104113
lmul!(Diagonal(weights), zs) # zs = W * ys
105114
end
106-
mul!(cs, A', zs) # cs = A' * (W * ys)
107-
lmul!(2, cs) # cs = 2 * A' * (W * ys)
115+
mul!(cs_lin, A', zs) # cs = A' * (W * ys)
116+
lmul!(2, cs_lin) # cs = 2 * A' * (W * ys)
108117

109-
ldiv!(F, cs) # solve linear system
118+
ldiv!(F, cs_lin) # solve linear system
110119

111120
Spline(R, cs)
112121
end
@@ -120,9 +129,9 @@ function fit(
120129
λ 0 || throw(DomainError(λ, "the smoothing parameter λ must be non-negative"))
121130
eachindex(xs) == eachindex(ys) || throw(DimensionMismatch("x and y vectors must have the same length"))
122131
N = length(xs)
123-
cs = similar(xs)
132+
cs = similar(ys)
124133

125-
T = eltype(cs)
134+
T = eltype(xs)
126135
ts_in = make_knots(xs, order_in, bc)
127136
R = PeriodicBSplineBasis(order_in, ts_in)
128137
k = order(R) # = 4
@@ -170,15 +179,25 @@ function fit(
170179
F = cholesky(M) # factorise matrix (assuming posdef)
171180

172181
# Construct RHS trying to reduce allocations
173-
zs = copy(ys)
182+
Y = eltype(ys)
183+
if ndims(Y) == 0 # scalar data
184+
cs_lin = cs
185+
zs = copy(ys)
186+
else # multidimensional data
187+
Z = eltype(eltype(ys)) # for example, if eltype(ys) <: SVector{D,Z}
188+
@assert Z <: Number
189+
cs_lin = reinterpret(reshape, Z, cs)'
190+
zs = copy(reinterpret(reshape, Z, ys)') # dimensions (N, D)
191+
end
174192
if weights !== nothing
175193
eachindex(weights) == eachindex(xs) || throw(DimensionMismatch("the `weights` vector must have the same length as the data"))
176194
lmul!(Diagonal(weights), zs) # zs = W * ys
177195
end
178-
mul!(cs, A', zs) # cs = A' * (W * ys)
179-
lmul!(2, cs) # cs = 2 * A' * (W * ys)
196+
@show summary(cs_lin) summary(A) summary(zs)
197+
mul!(cs_lin, A', zs) # cs = A' * (W * ys)
198+
lmul!(2, cs_lin) # cs = 2 * A' * (W * ys)
180199

181-
cs .= F \ cs # solve linear system (allocates intermediate array)
200+
cs_lin .= F \ cs_lin # solve linear system (allocates intermediate array)
182201

183202
Spline(R, cs)
184203
end

test/smoothing.jl

+43-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
using BSplineKit
44
using QuadGK: quadgk
5+
using StaticArrays
56
using ReverseDiff
67
using Test
78

@@ -49,7 +50,8 @@ function smoothing_objective(cs, R::AbstractBSplineBasis, xs, ys; weights = noth
4950
loss
5051
end
5152

52-
function check_zero_gradient(S::Spline, xs, ys; weights = nothing, λ)
53+
# Scalar data
54+
function _check_zero_gradient(::Type{T}, S::Spline, xs, ys; weights = nothing, λ, rtol = 1e-12) where {T <: Real}
5355
R = basis(S) # usually a RecombinedBSplineBasis
5456
cs = parent(coefficients(S)) # `parent` is useful if this is a PeriodicBSplineBasis
5557

@@ -67,11 +69,27 @@ function check_zero_gradient(S::Spline, xs, ys; weights = nothing, λ)
6769
# f ~ Y² and therefore ∂f/∂cⱼ ~ Y. So we compare it with the sum of |y_i|².
6870
reference = sum(abs2, ys)
6971
err = sum(abs2, ∇f)
70-
@test err / reference < 1e-12
72+
@test err / reference < rtol
7173

7274
nothing
7375
end
7476

77+
# Vector data (e.g. parametric splines)
78+
# Currently all components are separately smoothed, so we verify them as separate scalar functions.
79+
function _check_zero_gradient(::Type{SVector{N, T}}, S::Spline, xs, ys; kws...) where {N, T}
80+
R = basis(S)
81+
cs = coefficients(S)
82+
for i in 1:N
83+
Si = Spline(R, getindex.(cs, i))
84+
check_zero_gradient(Si, xs, getindex.(ys, i); kws...)
85+
end
86+
nothing
87+
end
88+
89+
function check_zero_gradient(S::Spline, xs, ys; kws...)
90+
_check_zero_gradient(eltype(ys), S, xs, ys; kws...)
91+
end
92+
7593
# Returns the integral of |S''(x)| (the "curvature") over the whole spline.
7694
function total_curvature(S::Spline)
7795
ts = knots(S)
@@ -148,4 +166,27 @@ end
148166
Sw = fit(xs, ys, λ; weights)
149167
check_zero_gradient(Sw, xs, ys; λ, weights)
150168
end
169+
170+
@testset "Parametric" begin
171+
λ = 1e-2
172+
N = 100
173+
ts = range(0, 2π; length = N + 1)[1:N]
174+
vs = [0.1 * SVector(cos(t), sin(t)) .+ 0.01 * sin(10 * t) for t in ts]
175+
S_nat = @inferred fit(ts, vs, λ, Natural())
176+
S_per = @inferred fit(ts, vs, λ, Periodic(2π))
177+
178+
@testset "Natural" check_zero_gradient(S_nat, ts, vs; λ)
179+
@testset "Periodic" check_zero_gradient(S_per, ts, vs; λ)
180+
181+
@testset "With weights" begin
182+
weights = fill!(similar(ts), 1)
183+
weights[3] = 1000
184+
185+
S_nat = @inferred fit(ts, vs, λ, Natural(); weights)
186+
S_per = @inferred fit(ts, vs, λ, Periodic(2π); weights)
187+
188+
@testset "Natural" check_zero_gradient(S_nat, ts, vs; λ, weights)
189+
@testset "Periodic" check_zero_gradient(S_per, ts, vs; λ, weights)
190+
end
191+
end
151192
end

0 commit comments

Comments
 (0)