From 1c19fa7f15538df0eaf922ca80b0648e3f044d8c Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 15 Oct 2023 21:39:03 -0400
Subject: [PATCH] Wrap FastLM.jl

Add `FastLevenbergMarquardtSolver`, a wrapper over FastLevenbergMarquardt.jl
for `NonlinearLeastSquaresProblem`, implemented as the package extension
`NonlinearSolveFastLevenbergMarquardtExt` (weak dependency).

- The extension requires a user-supplied Jacobian (`prob.f.jac`) and passes
  `abstol`/`reltol`/`maxiters` to FastLM as `xtol`/`ftol`/`maxit`.
- Move the `LSOptimSolver` constructor out of the struct and warn when the
  backing package has not been loaded; each extension flags itself through
  `NonlinearSolve.extension_loaded`.
- When an in-place problem has no `resid_prototype`, fall back to
  `zero(prob.u0)`; `zeros(prob.u0)` has no method for an array argument.
- The test Jacobian differentiates the in-place residual with respect to the
  parameters via the four-argument `ForwardDiff.jacobian!(J, f!, y, x)` form.

---
 Project.toml                                  |  6 +-
 ...NonlinearSolveFastLevenbergMarquardtExt.jl | 71 +++++++++++++++++++
 ext/NonlinearSolveLeastSquaresOptimExt.jl     |  2 +-
 src/NonlinearSolve.jl                         |  3 +-
 src/algorithms.jl                             | 62 ++++++++++++++--
 test/nonlinear_least_squares.jl               | 27 ++++---
 6 files changed, 155 insertions(+), 16 deletions(-)
 create mode 100644 ext/NonlinearSolveFastLevenbergMarquardtExt.jl

diff --git a/Project.toml b/Project.toml
index b4f7c2553..ad8c24075 100644
--- a/Project.toml
+++ b/Project.toml
@@ -25,9 +25,11 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
 UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
 
 [weakdeps]
+FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
 LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
 
 [extensions]
+NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
 NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
 
 [compat]
@@ -37,6 +39,7 @@ ConcreteStructs = "0.2"
 DiffEqBase = "6.130"
 EnumX = "1"
 Enzyme = "0.11"
+FastLevenbergMarquardt = "0.1"
 FiniteDiff = "2"
 ForwardDiff = "0.10.3"
 LeastSquaresOptim = "0.8"
@@ -57,6 +60,7 @@ julia = "1.9"
 [extras]
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -72,4 +76,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim"]
+test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt"]
diff --git a/ext/NonlinearSolveFastLevenbergMarquardtExt.jl b/ext/NonlinearSolveFastLevenbergMarquardtExt.jl
new file mode 100644
index 000000000..7a853c327
--- /dev/null
+++ b/ext/NonlinearSolveFastLevenbergMarquardtExt.jl
@@ -0,0 +1,71 @@
+module NonlinearSolveFastLevenbergMarquardtExt
+
+using ArrayInterface, NonlinearSolve, SciMLBase
+import ConcreteStructs: @concrete
+import FastLevenbergMarquardt as FastLM
+
+NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true
+
+function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve}
+    if linsolve == :cholesky
+        return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
+    elseif linsolve == :qr
+        return FastLM.QRSolver(eltype(x), length(x))
+    else
+        throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
+    end
+end
+
+@concrete struct FastLMCache
+    f!
+    J!
+    prob
+    alg
+    lmworkspace
+    solver
+    kwargs
+end
+
+@concrete struct InplaceFunction{iip} <: Function
+    f
+end
+
+(f::InplaceFunction{true})(fx, x, p) = f.f(fx, x, p)
+(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))
+
+function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
+    alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8,
+    verbose = false, maxiters = 1000, kwargs...)
+    iip = SciMLBase.isinplace(prob)
+
+    @assert prob.f.jac !== nothing "FastLevenbergMarquardt requires a Jacobian!"
+
+    f! = InplaceFunction{iip}(prob.f)
+    J! = InplaceFunction{iip}(prob.f.jac)
+
+    resid_prototype = prob.f.resid_prototype === nothing ?
+                      (!iip ? prob.f(prob.u0, prob.p) : zero(prob.u0)) :
+                      prob.f.resid_prototype
+
+    J = similar(prob.u0, length(resid_prototype), length(prob.u0))
+
+    solver = _fast_lm_solver(alg, prob.u0)
+    LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J)
+
+    return FastLMCache(f!, J!, prob, alg, LM, solver,
+        (; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
+            alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
+            alg.maxfactor, kwargs...))
+end
+
+function SciMLBase.solve!(cache::FastLMCache)
+    res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
+        cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
+    stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
+    retcode = info == 1 ? ReturnCode.Success :
+              (info == -1 ? ReturnCode.MaxIters : ReturnCode.Default)
+    return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
+        retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
+end
+
+end
diff --git a/ext/NonlinearSolveLeastSquaresOptimExt.jl b/ext/NonlinearSolveLeastSquaresOptimExt.jl
index 7514931d5..40299f5b3 100644
--- a/ext/NonlinearSolveLeastSquaresOptimExt.jl
+++ b/ext/NonlinearSolveLeastSquaresOptimExt.jl
@@ -4,7 +4,7 @@ using NonlinearSolve, SciMLBase
 import ConcreteStructs: @concrete
 import LeastSquaresOptim as LSO
 
-extension_loaded(::Val{:LeastSquaresOptim}) = true
+NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true
 
 function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
     ls = linsolve == :qr ? LSO.QR() :
diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl
index d9dcbba39..efcce13f5 100644
--- a/src/NonlinearSolve.jl
+++ b/src/NonlinearSolve.jl
@@ -95,7 +95,8 @@ end
 
 export RadiusUpdateSchemes
 
-export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton, LSOptimSolver
+export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
+export LSOptimSolver, FastLevenbergMarquardtSolver
 
 export LineSearch
 
diff --git a/src/algorithms.jl b/src/algorithms.jl
index b9a694866..3f288433e 100644
--- a/src/algorithms.jl
+++ b/src/algorithms.jl
@@ -11,18 +11,70 @@ Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquare
 - `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If `nothing`,
   then `LeastSquaresOptim.jl` will choose the best linear solver based on the Jacobian
   structure.
+- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.
 
 !!! note
     This algorithm is only available if `LeastSquaresOptim.jl` is installed.
 """
 struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
     autodiff::Symbol
+end
+
+function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
+    @assert alg in (:lm, :dogleg)
+    @assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
+    @assert autodiff in (:central, :forward)
+
+    if !extension_loaded(Val(:LeastSquaresOptim))
+        @warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \
+               before `solve(prob, LSOptimSolver())` is called."
+    end
+
+    return LSOptimSolver{alg, linsolve}(autodiff)
+end
+
+"""
+    FastLevenbergMarquardtSolver(linsolve = :cholesky)
+
+Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl)
+for solving `NonlinearLeastSquaresProblem`.
 
-    function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
-        @assert alg in (:lm, :dogleg)
-        @assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
-        @assert autodiff in (:central, :forward)
+!!! warning
+    This is not really the fastest solver. It is called that since the original package
+    is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
 
-        return new{alg, linsolve}(autodiff)
+!!! warning
+    This algorithm requires the Jacobian function to be provided!
+
+## Arguments:
+
+- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
+
+!!! note
+    This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
+"""
+@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm
+    factor
+    factoraccept
+    factorreject
+    factorupdate::Symbol
+    minscale
+    maxscale
+    minfactor
+    maxfactor
+end
+
+function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6,
+    factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
+    minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
+    @assert linsolve in (:qr, :cholesky)
+    @assert factorupdate in (:marquardt, :nielson)
+
+    if !extension_loaded(Val(:FastLevenbergMarquardt))
+        @warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \
+               before `solve(prob, FastLevenbergMarquardtSolver())` is called."
     end
+
+    return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject,
+        factorupdate, minscale, maxscale, minfactor, maxfactor)
 end
diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl
index 9ea4b0eca..4f67ef28e 100644
--- a/test/nonlinear_least_squares.jl
+++ b/test/nonlinear_least_squares.jl
@@ -1,5 +1,5 @@
-using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random
-import LeastSquaresOptim
+using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random, ForwardDiff
+import FastLevenbergMarquardt, LeastSquaresOptim
 
 true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
 true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
@@ -27,15 +27,26 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
         resid_prototype = zero(y_target)), θ_init, x)
 
 nlls_problems = [prob_oop, prob_iip]
-solvers = [
-    GaussNewton(),
-    LevenbergMarquardt(),
-    LSOptimSolver(:lm),
-    LSOptimSolver(:dogleg),
-]
+solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)]
 
 for prob in nlls_problems, solver in solvers
     @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
     @test SciMLBase.successful_retcode(sol)
     @test norm(sol.resid) < 1e-6
 end
+
+function jac!(J, θ, p)
+    ForwardDiff.jacobian!(J, (r, u) -> loss_function(r, u, p), similar(θ, size(J, 1)), θ)
+    return J
+end
+
+prob = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
+        resid_prototype = zero(y_target), jac = jac!), θ_init, x)
+
+solvers = [FastLevenbergMarquardtSolver(:cholesky), FastLevenbergMarquardtSolver(:qr)]
+
+for solver in solvers
+    @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
+    @test SciMLBase.successful_retcode(sol)
+    @test norm(sol.resid) < 1e-6
+end