Skip to content

Commit 2fad512

Browse files
committed
Add smoothing tests using AD
1 parent cf72cb4 commit 2fad512

File tree

2 files changed

+65
-2
lines changed

2 files changed

+65
-2
lines changed

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
44
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
55
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
66
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
78
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
89
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
910
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

test/smoothing.jl

+64-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,67 @@
22

33
using BSplineKit
44
using QuadGK: quadgk
5+
using ReverseDiff
56
using Test
67

8+
# This is the objective function that `fit` is supposed to minimise.
9+
# We can verify this using automatic differentiation: the gradient wrt the spline
10+
# coefficients should be zero.
11+
function smoothing_objective(cs, R::AbstractBSplineBasis, xs, ys; weights = nothing, λ)
12+
# Construct spline from coefficients and knots
13+
S = Spline(R, cs)
14+
S″ = Derivative(2) * S
15+
16+
# Compute first term of objective (loss) function
17+
T = eltype(cs)
18+
loss = zero(T)
19+
for i in eachindex(xs, ys)
20+
w = weights === nothing ? 1 : weights[i]
21+
loss += w * abs2(ys[i] - S(xs[i]))
22+
end
23+
24+
# Integrate roughness term interval by interval
25+
for i in eachindex(xs)[1:end-1]
26+
# Note: S″(x) is linear within each interval, and thus the integrand is quadratic.
27+
# Therefore, a two-point GL quadrature is exact (weights = 1 and locations = ±1/√3).
28+
a = xs[i]
29+
b = xs[i + 1]
30+
Δ = (b - a) / 2
31+
xc = (a + b) / 2
32+
gl_weight = 1
33+
gl_ξ = 1 / sqrt(T(3))
34+
for ξ in (-gl_ξ, +gl_ξ)
35+
x = Δ * ξ + xc
36+
loss += λ * Δ * gl_weight * abs2(S″(x))
37+
end
38+
end
39+
40+
loss
41+
end
42+
43+
function check_zero_gradient(S::Spline, xs, ys; weights = nothing, λ)
44+
cs = coefficients(S)
45+
R = basis(S) # usually a RecombinedBSplineBasis
46+
47+
# Not sure how useful this is...
48+
∇f = similar(cs) # gradient wrt coefficients
49+
inputs = (cs,)
50+
results = (∇f,)
51+
# all_results = map(DiffResults.GradientResult, results)
52+
cfg = ReverseDiff.GradientConfig(inputs)
53+
54+
# Compute gradient
55+
ReverseDiff.gradient!(results, cs -> smoothing_objective(cs, R, xs, ys; weights, λ), inputs, cfg)
56+
57+
# Verify that |∇f|² is negligible. Note that is has the same units as |y_i|² ≡ Y², since
58+
# f ~ Y² and therefore ∂f/∂cⱼ ~ Y. So we compare it with the sum of |y_i|².
59+
reference = sum(abs2, ys)
60+
err = sum(abs2, ∇f)
61+
@test err / reference < 1e-12
62+
63+
nothing
64+
end
65+
766
# Returns the integral of |S''(x)| (the "curvature") over the whole spline.
867
function total_curvature(S::Spline)
968
ts = knots(S)
@@ -41,10 +100,11 @@ end
41100
λs = [0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2]
42101
curvatures = similar(λs)
43102
distances = similar(λs)
44-
for i in eachindex(λs)
45-
S = @inferred fit(xs, ys, λs[i])
103+
for (i, λ) in pairs(λs)
104+
S = @inferred fit(xs, ys, λ)
46105
curvatures[i] = total_curvature(S)
47106
distances[i] = distance_from_data(S, xs, ys)
107+
check_zero_gradient(S, xs, ys; λ)
48108
end
49109
@test issorted(curvatures; rev = true) # in decreasing order (small λ => large curvature)
50110
@test issorted(distances) # in increasing order (large λ => large distance from data)
@@ -55,10 +115,12 @@ end
55115
weights = fill!(similar(xs), 1)
56116
S = fit(xs, ys, λ)
57117
Sw = @inferred fit(xs, ys, λ; weights) # equivalent to the default (all weights are 1)
118+
check_zero_gradient(Sw, xs, ys; λ, weights)
58119
@test coefficients(S) == coefficients(Sw)
59120
# Now give more weight to point i = 3
60121
weights[3] = 1000
61122
Sw = fit(xs, ys, λ; weights)
123+
check_zero_gradient(Sw, xs, ys; λ, weights)
62124
@test abs(Sw(xs[3]) - ys[3]) < abs(S(xs[3]) - ys[3]) # the new curve is closer to the data point i = 3
63125
@test total_curvature(Sw) > total_curvature(S) # since we give more importance to data fitting (basically, the sum of weights is larger)
64126
@test distance_from_data(Sw, xs, ys) < distance_from_data(S, xs, ys)

0 commit comments

Comments
 (0)