Skip to content

Commit

Permalink
More Robust QR for LM
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 23, 2023
1 parent 80e6459 commit fabd33c
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 54 deletions.
6 changes: 2 additions & 4 deletions docs/src/solvers/NonlinearLeastSquaresSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ Solves the nonlinear least squares problem defined by `prob` using the algorithm
handling of sparse matrices via colored automatic differentiation and preconditioned
linear solvers. Designed for large-scale and numerically-difficult nonlinear least squares
problems.
- `SimpleNewtonRaphson()`: Newton Raphson implementation that uses Linear Least Squares
solution at every step to compute the descent direction. **WARNING**: This method is not
a robust solver for nonlinear least squares problems. The computed delta step might not
be the correct descent direction!
- `SimpleNewtonRaphson()`: Simple Gauss-Newton implementation that uses `QRFactorization` to
solve a linear least squares problem at each step.

## Example usage

Expand Down
6 changes: 4 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,10 @@ import PrecompileTools
for T in (Float32, Float64)
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
# precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
# PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
# DON'T MERGE
precompile_algs = ()

for alg in precompile_algs
solve(prob, alg, abstol = T(1e-2))
Expand Down
159 changes: 116 additions & 43 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ An advanced Levenberg-Marquardt implementation with the improvements suggested i
algorithm for nonlinear least-squares minimization". Designed for large-scale and
numerically-difficult nonlinear systems.
If no `linsolve` is provided, or a variant of `QR` is provided, then an efficient routine
that factorizes without explicitly constructing `JᵀJ` and `Jᵀf` is used. For more details,
see "Chapter 10: Implementation of the Levenberg-Marquardt Method" of
["Numerical Optimization" by Jorge Nocedal & Stephen J. Wright](https://link.springer.com/book/10.1007/978-0-387-40065-5).
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
Expand Down Expand Up @@ -104,7 +109,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
end

@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
@concrete mutable struct LevenbergMarquardtCache{iip, fastqr} <:
AbstractNonlinearSolveCache{iip}
f
alg
u
Expand Down Expand Up @@ -144,6 +150,8 @@ end
u_tmp
Jv
mat_tmp
rhs_tmp
stats::NLStats
end

Expand All @@ -155,8 +163,26 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
fu1 = evaluate_f(prob, u)
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs, linsolve_with_JᵀJ = Val(true))

# Use QR if the user did not specify a linear solver
if (alg.linsolve === nothing || alg.linsolve isa QRFactorization ||
alg.linsolve isa FastQRFactorization) && !(u isa Number)
linsolve_with_JᵀJ = Val(false)

Check warning on line 170 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L170

Added line #L170 was not covered by tests
else
linsolve_with_JᵀJ = Val(true)
end

if _unwrap_val(linsolve_with_JᵀJ)
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p,
Val(iip); linsolve_kwargs, linsolve_with_JᵀJ)
J² = nothing
else
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);

Check warning on line 180 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L180

Added line #L180 was not covered by tests
linsolve_kwargs, linsolve_with_JᵀJ)
JᵀJ = similar(u)
J² = similar(J)
v = similar(du)

Check warning on line 184 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L182-L184

Added lines #L182 - L184 were not covered by tests
end

λ = convert(eltype(u), alg.damping_initial)
λ_factor = convert(eltype(u), alg.damping_increase_factor)
Expand All @@ -182,16 +208,26 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
δ = _mutable_zero(u)
make_new_J = true
fu_tmp = zero(fu1)
mat_tmp = zero(JᵀJ)

return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
if _unwrap_val(linsolve_with_JᵀJ)
mat_tmp = zero(JᵀJ)
rhs_tmp = nothing
else
mat_tmp = similar(JᵀJ, length(fu1) + length(u), length(u))
fill!(mat_tmp, zero(eltype(u)))
rhs_tmp = similar(mat_tmp, length(fu1) + length(u))
fill!(rhs_tmp, zero(eltype(u)))

Check warning on line 219 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L216-L219

Added lines #L216 - L219 were not covered by tests
end

return LevenbergMarquardtCache{iip, !_unwrap_val(linsolve_with_JᵀJ)}(f, alg, u, fu1,
fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0))
zero(u), zero(fu1), mat_tmp, rhs_tmp, J², NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::LevenbergMarquardtCache{true})
function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fastqr}
@unpack fu1, f, make_new_J = cache
if iszero(fu1)
cache.force_stop = true
Expand All @@ -200,35 +236,57 @@ function perform_step!(cache::LevenbergMarquardtCache{true})

if make_new_J
jacobian!!(cache.J, cache)
__matmul!(cache.JᵀJ, cache.J', cache.J)
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
if fastqr
cache.J² .= cache.J .^ 2
sum!(cache.JᵀJ', cache.J²)
cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)

Check warning on line 242 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L240-L242

Added lines #L240 - L242 were not covered by tests
else
__matmul!(cache.JᵀJ, cache.J', cache.J)
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
end
cache.make_new_J = false
cache.stats.njacs += 1
end
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache

# Usual Levenberg-Marquardt step ("velocity").
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
mul!(_vec(cache.u_tmp), J', _vec(fu1))
@. cache.mat_tmp = JᵀJ + λ * DᵀD
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
_vec(cache.v) .= -1 .* _vec(cache.du)
if fastqr
cache.mat_tmp[1:length(fu1), :] .= cache.J
cache.mat_tmp[(length(fu1) + 1):end, :] .= λ .* cache.DᵀD
cache.rhs_tmp[1:length(fu1)] .= _vec(fu1)
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,

Check warning on line 258 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L255-L258

Added lines #L255 - L258 were not covered by tests
b = cache.rhs_tmp, linu = _vec(cache.du), p = p, reltol = cache.abstol)
_vec(cache.v) .= -_vec(cache.du)

Check warning on line 260 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L260

Added line #L260 was not covered by tests
else
mul!(_vec(cache.u_tmp), J', _vec(fu1))
@. cache.mat_tmp = JᵀJ + λ * DᵀD
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
_vec(cache.v) .= -_vec(cache.du)
end

# Geodesic acceleration (step_size = v + a / 2).
@unpack v, α_geodesic, h = cache
f(cache.fu_tmp, _restructure(u, _vec(u) .+ h .* _vec(v)), p)
cache.u_tmp .= _restructure(cache.u_tmp, _vec(u) .+ h .* _vec(v))
f(cache.fu_tmp, cache.u_tmp, p)

# The following lines do: cache.a = -J \ cache.fu_tmp
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
mul!(_vec(cache.Jv), J, _vec(v))
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
mul!(_vec(cache.u_tmp), J', _vec(cache.fu_tmp))
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.a = -cache.du
if fastqr
cache.rhs_tmp[1:length(fu1)] .= _vec(cache.fu_tmp)
linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.du),

Check warning on line 281 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L280-L281

Added lines #L280 - L281 were not covered by tests
p = p, reltol = cache.abstol)
else
mul!(_vec(cache.u_tmp), J', _vec(cache.fu_tmp))
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.a = -cache.du
end
cache.stats.nsolve += 2
cache.stats.nfactors += 2

Expand Down Expand Up @@ -263,7 +321,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
return nothing
end

function perform_step!(cache::LevenbergMarquardtCache{false})
function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fastqr}

Check warning on line 324 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L324

Added line #L324 was not covered by tests
@unpack fu1, f, make_new_J = cache
if iszero(fu1)
cache.force_stop = true
Expand All @@ -272,40 +330,55 @@ function perform_step!(cache::LevenbergMarquardtCache{false})

if make_new_J
cache.J = jacobian!!(cache.J, cache)
cache.JᵀJ = cache.J' * cache.J
if cache.JᵀJ isa Number
cache.DᵀD = max(cache.DᵀD, cache.JᵀJ)
if fastqr
cache.JᵀJ = _vec(sum(cache.J .^ 2; dims = 1))
cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)

Check warning on line 335 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L333-L335

Added lines #L333 - L335 were not covered by tests
else
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
cache.JᵀJ = cache.J' * cache.J
if cache.JᵀJ isa Number
cache.DᵀD = max(cache.DᵀD, cache.JᵀJ)

Check warning on line 339 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L337-L339

Added lines #L337 - L339 were not covered by tests
else
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))

Check warning on line 341 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L341

Added line #L341 was not covered by tests
end
end
cache.make_new_J = false
cache.stats.njacs += 1
end
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache

cache.mat_tmp = JᵀJ + λ * DᵀD
# Usual Levenberg-Marquardt step ("velocity").
if linsolve === nothing
cache.v = -cache.mat_tmp \ (J' * fu1)
if fastqr
cache.mat_tmp = vcat(J, λ .* cache.DᵀD)
cache.rhs_tmp[1:length(fu1)] .= -_vec(fu1)
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,

Check warning on line 353 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L350-L353

Added lines #L350 - L353 were not covered by tests
b = cache.rhs_tmp, linu = _vec(cache.v), p = p, reltol = cache.abstol)
else
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache
cache.mat_tmp = JᵀJ + λ * DᵀD
if linsolve === nothing
cache.v = -cache.mat_tmp \ (J' * fu1)

Check warning on line 358 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L356-L358

Added lines #L356 - L358 were not covered by tests
else
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),

Check warning on line 360 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L360

Added line #L360 was not covered by tests
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 362 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L362

Added line #L362 was not covered by tests
end
end

@unpack v, h, α_geodesic = cache
# Geodesic acceleration (step_size = v + a / 2).
if linsolve === nothing
cache.a = -cache.mat_tmp \
_vec(J' * ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))
else
rhs_term = _vec(((2 / h) .* ((_vec(f(u .+ h .* _restructure(u, v), p)) .-

Check warning on line 368 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L368

Added line #L368 was not covered by tests
_vec(fu1)) ./ h .- J * _vec(v))))
if fastqr
cache.rhs_tmp[1:length(fu1)] .= -_vec(rhs_term)

Check warning on line 371 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L370-L371

Added lines #L370 - L371 were not covered by tests
linres = dolinsolve(alg.precs, linsolve;
b = _mutable(_vec(J' * #((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
_vec(((2 / h) .*
((_vec(f(u .+ h .* _restructure(u, v), p)) .-
_vec(fu1)) ./ h .- J * _vec(v)))))),
linu = _vec(cache.a), p, reltol = cache.abstol)
cache.linsolve = linres.cache
b = cache.rhs_tmp, linu = _vec(cache.a), p = p, reltol = cache.abstol)
else
if linsolve === nothing
cache.a = -cache.mat_tmp \ _vec(J' * rhs_term)

Check warning on line 376 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L375-L376

Added lines #L375 - L376 were not covered by tests
else
linres = dolinsolve(alg.precs, linsolve; b = _mutable(_vec(J' * rhs_term)),

Check warning on line 378 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L378

Added line #L378 was not covered by tests
linu = _vec(cache.a), p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 380 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L380

Added line #L380 was not covered by tests
end
end
cache.stats.nsolve += 1
cache.stats.nfactors += 1
Expand Down
4 changes: 1 addition & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ function value_derivative(f::F, x::R) where {F, R}
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
end

# Todo: improve this dispatch
function value_derivative(f::F, x::SVector) where {F}
f(x), ForwardDiff.jacobian(f, x)
end
Expand Down Expand Up @@ -206,8 +205,7 @@ function __get_concrete_algorithm(alg, prob)
# Use Finite Differencing
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
else
use_sparse_ad ? AutoSparseForwardDiff() :
AutoForwardDiff{ForwardDiff.pickchunksize(length(prob.u0)), Nothing}(nothing)
use_sparse_ad ? AutoSparseForwardDiff() : AutoForwardDiff{nothing, Nothing}(nothing)
end
return set_ad(alg, ad)
end
Expand Down
10 changes: 8 additions & 2 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,14 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

nlls_problems = [prob_oop, prob_iip]
solvers = [GaussNewton(), LevenbergMarquardt(), LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg)]
solvers = [
GaussNewton(),
GaussNewton(; linsolve = CholeskyFactorization()),
LevenbergMarquardt(),
LevenbergMarquardt(; linsolve = CholeskyFactorization()),
LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg),
]

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
Expand Down

0 comments on commit fabd33c

Please sign in to comment.