From 22fd30a4f395e8ed600c669b51559fc70c2c483b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 7 Dec 2023 16:54:10 -0500 Subject: [PATCH] Automatically construct the jacobian for FastLM --- ...NonlinearSolveFastLevenbergMarquardtExt.jl | 69 +++++++++++++++---- ext/NonlinearSolveLeastSquaresOptimExt.jl | 10 +-- src/default.jl | 12 ++-- src/extension_algs.jl | 18 +++-- test/nonlinear_least_squares.jl | 31 +++++++-- 5 files changed, 107 insertions(+), 33 deletions(-) diff --git a/ext/NonlinearSolveFastLevenbergMarquardtExt.jl b/ext/NonlinearSolveFastLevenbergMarquardtExt.jl index e292e29c6..6d33c23ba 100644 --- a/ext/NonlinearSolveFastLevenbergMarquardtExt.jl +++ b/ext/NonlinearSolveFastLevenbergMarquardtExt.jl @@ -3,11 +3,12 @@ module NonlinearSolveFastLevenbergMarquardtExt using ArrayInterface, NonlinearSolve, SciMLBase import ConcreteStructs: @concrete import FastLevenbergMarquardt as FastLM +import FiniteDiff, ForwardDiff function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve} - if linsolve == :cholesky + if linsolve === :cholesky return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x)) - elseif linsolve == :qr + elseif linsolve === :qr return FastLM.QRSolver(eltype(x), length(x)) else throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve")) @@ -33,23 +34,65 @@ end function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = 1e-8, - reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...) + reltol = 1e-8, maxiters = 1000, kwargs...) iip = SciMLBase.isinplace(prob) - u0 = alias_u0 ? prob.u0 : deepcopy(prob.u0) - - @assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!" + u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0) + fu = NonlinearSolve.evaluate_f(prob, u) f! = InplaceFunction{iip}(prob.f) - J! = InplaceFunction{iip}(prob.f.jac) - resid_prototype = prob.f.resid_prototype === nothing ? - (!iip ? prob.f(u0, prob.p) : zeros(u0)) : - prob.f.resid_prototype + if prob.f.jac === nothing + use_forward_diff = if alg.autodiff === nothing + ForwardDiff.can_dual(eltype(u)) + else + alg.autodiff isa AutoForwardDiff + end + uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p) + if use_forward_diff + cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) : + ForwardDiff.JacobianConfig(uf, u) + else + cache = FiniteDiff.JacobianCache(u, fu) + end + J! = if iip + if use_forward_diff + fu_cache = similar(fu) + function (J, x, p) + uf.p = p + ForwardDiff.jacobian!(J, uf, fu_cache, x, cache) + return J + end + else + function (J, x, p) + uf.p = p + FiniteDiff.finite_difference_jacobian!(J, uf, x, cache) + return J + end + end + else + if use_forward_diff + function (J, x, p) + uf.p = p + ForwardDiff.jacobian!(J, uf, x, cache) + return J + end + else + function (J, x, p) + uf.p = p + J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache) + copyto!(J, J_) + return J + end + end + end + else + J! = InplaceFunction{iip}(prob.f.jac) + end - J = similar(u0, length(resid_prototype), length(u0)) + J = similar(u, length(fu), length(u)) - solver = _fast_lm_solver(alg, u0) - LM = FastLM.LMWorkspace(u0, resid_prototype, J) + solver = _fast_lm_solver(alg, u) + LM = FastLM.LMWorkspace(u, fu, J) return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver, (; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept, diff --git a/ext/NonlinearSolveLeastSquaresOptimExt.jl b/ext/NonlinearSolveLeastSquaresOptimExt.jl index 63ce7d0cc..004c26670 100644 --- a/ext/NonlinearSolveLeastSquaresOptimExt.jl +++ b/ext/NonlinearSolveLeastSquaresOptimExt.jl @@ -5,12 +5,12 @@ import ConcreteStructs: @concrete import LeastSquaresOptim as LSO function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve} - ls = linsolve == :qr ? LSO.QR() : - (linsolve == :cholesky ? LSO.Cholesky() : - (linsolve == :lsmr ? LSO.LSMR() : nothing)) - if alg == :lm + ls = linsolve === :qr ? LSO.QR() : + (linsolve === :cholesky ? LSO.Cholesky() : + (linsolve === :lsmr ? LSO.LSMR() : nothing)) + if alg === :lm return LSO.LevenbergMarquardt(ls) - elseif alg == :dogleg + elseif alg === :dogleg return LSO.Dogleg(ls) else throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg")) diff --git a/src/default.jl b/src/default.jl index fa1eef52b..8b69b5a06 100644 --- a/src/default.jl +++ b/src/default.jl @@ -244,8 +244,10 @@ function FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothi autodiff = nothing) where {JAC, SA} if JAC if SA - algs = (SimpleNewtonRaphson(; autodiff), - SimpleTrustRegion(; autodiff), + algs = (SimpleNewtonRaphson(; + autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)), + SimpleTrustRegion(; + autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)), NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(), autodiff), TrustRegion(; concrete_jac, linsolve, precs, @@ -263,8 +265,10 @@ function FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothi algs = (SimpleBroyden(), Broyden(; init_jacobian = Val(:true_jacobian)), SimpleKlement(), - SimpleNewtonRaphson(; autodiff), - SimpleTrustRegion(; autodiff), + SimpleNewtonRaphson(; + autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)), + SimpleTrustRegion(; + autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)), NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(), autodiff), TrustRegion(; concrete_jac, linsolve, precs, diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 8f9ed4400..3fe4b84fd 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -36,7 +36,7 @@ function LeastSquaresOptimJL(alg = :lm; linsolve = nothing, autodiff::Symbol = : end """ - FastLevenbergMarquardtJL(linsolve = :cholesky) + FastLevenbergMarquardtJL(linsolve = :cholesky; autodiff = nothing) Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving `NonlinearLeastSquaresProblem`. @@ -46,19 +46,20 @@ for solving `NonlinearLeastSquaresProblem`. This is not really the fastest solver. It is called that since the original package is called "Fast". `LevenbergMarquardt()` is almost always a better choice. -!!! warning - - This algorithm requires the jacobian function to be provided! - ## Arguments: - `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`. + - `autodiff`: determines the backend used for the Jacobian. Note that this argument is + ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to + `nothing` which means that a default is selected according to the problem specification! + Valid choices are `nothing`, `AutoForwardDiff` or `AutoFiniteDiff`. !!! note This algorithm is only available if `FastLevenbergMarquardt.jl` is installed. """ @concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm + autodiff factor factoraccept factorreject @@ -71,14 +72,17 @@ end function FastLevenbergMarquardtJL(linsolve::Symbol = :cholesky; factor = 1e-6, factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt, - minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32) + minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32, + autodiff = nothing) @assert linsolve in (:qr, :cholesky) @assert factorupdate in (:marquardt, :nielson) + @assert autodiff === nothing || autodiff isa AutoFiniteDiff || + autodiff isa AutoForwardDiff if Base.get_extension(@__MODULE__, :NonlinearSolveFastLevenbergMarquardtExt) === nothing error("LeastSquaresOptimJL requires FastLevenbergMarquardt.jl to be loaded") end - return FastLevenbergMarquardtJL{linsolve}(factor, factoraccept, factorreject, + return FastLevenbergMarquardtJL{linsolve}(autodiff, factor, factoraccept, factorreject, factorupdate, minscale, maxscale, minfactor, maxfactor) end diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl index 7729f74d1..330c2f5da 100644 --- a/test/nonlinear_least_squares.jl +++ b/test/nonlinear_least_squares.jl @@ -89,12 +89,35 @@ function jac!(J, θ, p) return J end -prob = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function; - resid_prototype = zero(y_target), jac = jac!), θ_init, x) +jac(θ, p) = ForwardDiff.jacobian(θ -> loss_function(θ, p), θ) -solvers = [FastLevenbergMarquardtJL(:cholesky), FastLevenbergMarquardtJL(:qr)] +probs = [ + NonlinearLeastSquaresProblem(NonlinearFunction{true}(loss_function; + resid_prototype = zero(y_target), jac = jac!), θ_init, x), + NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function; + resid_prototype = zero(y_target), jac = jac), θ_init, x), + NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function; jac), θ_init, x), +] + +solvers = [FastLevenbergMarquardtJL(linsolve) for linsolve in (:cholesky, :qr)] + +for solver in solvers, prob in probs + @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8) + @test norm(sol.resid) < 1e-6 +end + +probs = [ + NonlinearLeastSquaresProblem(NonlinearFunction{true}(loss_function; + resid_prototype = zero(y_target)), θ_init, x), + NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function; + resid_prototype = zero(y_target)), θ_init, x), + NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function), θ_init, x), +] + +solvers = [FastLevenbergMarquardtJL(linsolve; autodiff) for linsolve in (:cholesky, :qr), +autodiff in (nothing, AutoForwardDiff(), AutoFiniteDiff())] -for solver in solvers +for solver in solvers, prob in probs @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8) @test norm(sol.resid) < 1e-6 end