Skip to content

Commit

Permalink
Wrap FastLM.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 16, 2023
1 parent f151a0a commit 1c19fa7
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 16 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"

[extensions]
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"

[compat]
Expand All @@ -37,6 +39,7 @@ ConcreteStructs = "0.2"
DiffEqBase = "6.130"
EnumX = "1"
Enzyme = "0.11"
FastLevenbergMarquardt = "0.1"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
LeastSquaresOptim = "0.8"
Expand All @@ -57,6 +60,7 @@ julia = "1.9"
[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -72,4 +76,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt"]
71 changes: 71 additions & 0 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
module NonlinearSolveFastLevenbergMarquardtExt

using ArrayInterface, NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import FastLevenbergMarquardt as FastLM

NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true

function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve}
if linsolve == :cholesky
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
elseif linsolve == :qr
return FastLM.QRSolver(eltype(x), length(x))
else
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))

Check warning on line 15 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L15

Added line #L15 was not covered by tests
end
end

@concrete struct FastLMCache
f!
J!
prob
alg
lmworkspace
solver
kwargs
end

@concrete struct InplaceFunction{iip} <: Function
f
end

(f::InplaceFunction{true})(fx, x, p) = f.f(fx, x, p)
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))

Check warning on line 34 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L34

Added line #L34 was not covered by tests

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8,
verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

@assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!"

f! = InplaceFunction{iip}(prob.f)
J! = InplaceFunction{iip}(prob.f.jac)

resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
prob.f.resid_prototype

J = similar(prob.u0, length(resid_prototype), length(prob.u0))

solver = _fast_lm_solver(alg, prob.u0)
LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J)

return FastLMCache(f!, J!, prob, alg, LM, solver,
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
alg.maxfactor, kwargs...))
end

function SciMLBase.solve!(cache::FastLMCache)
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
retcode = info == 1 ? ReturnCode.Success :
(info == -1 ? ReturnCode.MaxIters : ReturnCode.Default)
return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
end

end
2 changes: 1 addition & 1 deletion ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import LeastSquaresOptim as LSO

extension_loaded(::Val{:LeastSquaresOptim}) = true
NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true

function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
ls = linsolve == :qr ? LSO.QR() :
Expand Down
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ end

export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton, LSOptimSolver
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
export LSOptimSolver, FastLevenbergMarquardtSolver

export LineSearch

Expand Down
62 changes: 57 additions & 5 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,70 @@ Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquare
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
on the Jacobian structure.
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.
!!! note
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
autodiff::Symbol
end

function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
@assert alg in (:lm, :dogleg)
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
@assert autodiff in (:central, :forward)

if !extension_loaded(Val(:LeastSquaresOptim))
@warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \

Check warning on line 29 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L29

Added line #L29 was not covered by tests
before `solve(prob, LSOptimSolver())` is called."
end

return LSOptimSolver{alg, linsolve}(autodiff)
end

"""
FastLevenbergMarquardtSolver(linsolve = :cholesky)
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
`NonlinearLeastSquaresProblem`.
function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
@assert alg in (:lm, :dogleg)
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
@assert autodiff in (:central, :forward)
!!! warning
This is not really the fastest solver. It is called that since the original package
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
return new{alg, linsolve}(autodiff)
!!! warning
This algorithm requires the jacobian function to be provided!
## Arguments:
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
!!! note
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
"""
@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm
factor
factoraccept
factorreject
factorupdate::Symbol
minscale
maxscale
minfactor
maxfactor
end

function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6,
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
@assert linsolve in (:qr, :cholesky)
@assert factorupdate in (:marquardt, :nielson)

if !extension_loaded(Val(:FastLevenbergMarquardt))
@warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \

Check warning on line 74 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L74

Added line #L74 was not covered by tests
before `solve(prob, FastLevenbergMarquardtSolver())` is called."
end

return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject,
factorupdate, minscale, maxscale, minfactor, maxfactor)
end
27 changes: 19 additions & 8 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random
import LeastSquaresOptim
using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random, ForwardDiff
import FastLevenbergMarquardt, LeastSquaresOptim

true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
Expand Down Expand Up @@ -27,15 +27,26 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

nlls_problems = [prob_oop, prob_iip]
solvers = [
GaussNewton(),
LevenbergMarquardt(),
LSOptimSolver(:lm),
LSOptimSolver(:dogleg),
]
solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)]

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6
end

function jac!(J, θ, p)
ForwardDiff.jacobian!(J, resid -> loss_function(resid, θ, p), θ)
return J
end

prob = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target), jac = jac!), θ_init, x)

solvers = [FastLevenbergMarquardtSolver(:cholesky), FastLevenbergMarquardtSolver(:qr)]

for solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6
end

0 comments on commit 1c19fa7

Please sign in to comment.