From 17693fbcea4272286ff98603f00910fa6f226e51 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 16 Oct 2023 11:11:31 -0400 Subject: [PATCH] Rename wrappers to be consistent with other SciML packages --- Project.toml | 2 +- ...NonlinearSolveFastLevenbergMarquardtExt.jl | 12 +++---- ext/NonlinearSolveLeastSquaresOptimExt.jl | 14 ++++---- src/NonlinearSolve.jl | 4 +-- src/{algorithms.jl => extension_algs.jl} | 33 +++++++++---------- test/nonlinear_least_squares.jl | 5 +-- 6 files changed, 32 insertions(+), 38 deletions(-) rename src/{algorithms.jl => extension_algs.jl} (59%) diff --git a/Project.toml b/Project.toml index ad8c24075..1205910b7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "2.3.0" +version = "2.4.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/NonlinearSolveFastLevenbergMarquardtExt.jl b/ext/NonlinearSolveFastLevenbergMarquardtExt.jl index 7a853c327..4a096747d 100644 --- a/ext/NonlinearSolveFastLevenbergMarquardtExt.jl +++ b/ext/NonlinearSolveFastLevenbergMarquardtExt.jl @@ -4,9 +4,7 @@ using ArrayInterface, NonlinearSolve, SciMLBase import ConcreteStructs: @concrete import FastLevenbergMarquardt as FastLM -NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true - -function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve} +function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve} if linsolve == :cholesky return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x)) elseif linsolve == :qr @@ -16,7 +14,7 @@ function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {lin end end -@concrete struct FastLMCache +@concrete struct FastLevenbergMarquardtJLCache f! J! prob @@ -34,7 +32,7 @@ end (f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p)) function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, - alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8, + alg::FastLevenbergMarquardtJL, args...; abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...) iip = SciMLBase.isinplace(prob) @@ -52,13 +50,13 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, solver = _fast_lm_solver(alg, prob.u0) LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J) - return FastLMCache(f!, J!, prob, alg, LM, solver, + return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver, (; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept, alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor, alg.maxfactor, kwargs...)) end -function SciMLBase.solve!(cache::FastLMCache) +function SciMLBase.solve!(cache::FastLevenbergMarquardtJLCache) res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!, cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...) stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter) diff --git a/ext/NonlinearSolveLeastSquaresOptimExt.jl b/ext/NonlinearSolveLeastSquaresOptimExt.jl index 40299f5b3..e1f3cdfb7 100644 --- a/ext/NonlinearSolveLeastSquaresOptimExt.jl +++ b/ext/NonlinearSolveLeastSquaresOptimExt.jl @@ -4,9 +4,7 @@ using NonlinearSolve, SciMLBase import ConcreteStructs: @concrete import LeastSquaresOptim as LSO -NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true - -function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve} +function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve} ls = linsolve == :qr ? LSO.QR() : (linsolve == :cholesky ? LSO.Cholesky() : (linsolve == :lsmr ? LSO.LSMR() : nothing)) @@ -19,7 +17,7 @@ function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve} end end -@concrete struct LeastSquaresOptimCache +@concrete struct LeastSquaresOptimJLCache prob alg allocated_prob @@ -34,8 +32,8 @@ end (f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p) (f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p)) -function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...; - abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...) +function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LeastSquaresOptimJL, + args...; abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...) iip = SciMLBase.isinplace(prob) f! = FunctionWrapper{iip}(prob.f, prob.p) @@ -49,12 +47,12 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype)) allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg)) - return LeastSquaresOptimCache(prob, alg, allocated_prob, + return LeastSquaresOptimJLCache(prob, alg, allocated_prob, (; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose, kwargs...)) end -function SciMLBase.solve!(cache::LeastSquaresOptimCache) +function SciMLBase.solve!(cache::LeastSquaresOptimJLCache) res = LSO.optimize!(cache.allocated_prob; cache.kwargs...) maxiters = cache.kwargs[:iterations] retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success : diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index efcce13f5..dbeaefbfe 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -30,8 +30,6 @@ abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm abstract type AbstractNonlinearSolveCache{iip} end -extension_loaded(::Val) = false - isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, @@ -96,7 +94,7 @@ end export RadiusUpdateSchemes export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton -export LSOptimSolver, FastLevenbergMarquardtSolver +export LeastSquaresOptimJL, FastLevenbergMarquardtJL export LineSearch diff --git a/src/algorithms.jl b/src/extension_algs.jl similarity index 59% rename from src/algorithms.jl rename to src/extension_algs.jl index 3f288433e..6038a7623 100644 --- a/src/algorithms.jl +++ b/src/extension_algs.jl @@ -1,9 +1,10 @@ -# Define Algorithms extended via extensions +# This file only include the algorithm struct to be exported by LinearSolve.jl. The main +# functionality is implemented as package extensions """ - LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central) + LeastSquaresOptimJL(alg = :lm; linsolve = nothing, autodiff::Symbol = :central) -Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl) for solving -`NonlinearLeastSquaresProblem`. +Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl) +for solving `NonlinearLeastSquaresProblem`. ## Arguments: @@ -16,25 +17,24 @@ Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquare !!! note This algorithm is only available if `LeastSquaresOptim.jl` is installed. """ -struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm +struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm autodiff::Symbol end -function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central) +function LeastSquaresOptimJL(alg = :lm; linsolve = nothing, autodiff::Symbol = :central) @assert alg in (:lm, :dogleg) @assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr) @assert autodiff in (:central, :forward) - if !extension_loaded(Val(:LeastSquaresOptim)) - @warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \ - before `solve(prob, LSOptimSolver())` is called." + if Base.get_extension(@__MODULE__, :NonlinearSolveLeastSquaresOptimExt) === nothing + error("LeastSquaresOptimJL requires LeastSquaresOptim.jl to be loaded") end - return LSOptimSolver{alg, linsolve}(autodiff) + return LeastSquaresOptimJL{alg, linsolve}(autodiff) end """ - FastLevenbergMarquardtSolver(linsolve = :cholesky) + FastLevenbergMarquardtJL(linsolve = :cholesky) Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving `NonlinearLeastSquaresProblem`. @@ -53,7 +53,7 @@ Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenberg !!! note This algorithm is only available if `FastLevenbergMarquardt.jl` is installed. """ -@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm +@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm factor factoraccept factorreject @@ -64,17 +64,16 @@ Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenberg maxfactor end -function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6, +function FastLevenbergMarquardtJL(linsolve::Symbol = :cholesky; factor = 1e-6, factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt, minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32) @assert linsolve in (:qr, :cholesky) @assert factorupdate in (:marquardt, :nielson) - if !extension_loaded(Val(:FastLevenbergMarquardt)) - @warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \ - before `solve(prob, FastLevenbergMarquardtSolver())` is called." + if Base.get_extension(@__MODULE__, :NonlinearSolveFastLevenbergMarquardtExt) === nothing + error("LeastSquaresOptimJL requires FastLevenbergMarquardt.jl to be loaded") end - return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject, + return FastLevenbergMarquardtJL{linsolve}(factor, factoraccept, factorreject, factorupdate, minscale, maxscale, minfactor, maxfactor) end diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl index 4f67ef28e..0f3d3e898 100644 --- a/test/nonlinear_least_squares.jl +++ b/test/nonlinear_least_squares.jl @@ -27,7 +27,8 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function; resid_prototype = zero(y_target)), θ_init, x) nlls_problems = [prob_oop, prob_iip] -solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)] +solvers = [GaussNewton(), LevenbergMarquardt(), LeastSquaresOptimJL(:lm), + LeastSquaresOptimJL(:dogleg)] for prob in nlls_problems, solver in solvers @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8) @@ -43,7 +44,7 @@ end prob = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function; resid_prototype = zero(y_target), jac = jac!), θ_init, x) -solvers = [FastLevenbergMarquardtSolver(:cholesky), FastLevenbergMarquardtSolver(:qr)] +solvers = [FastLevenbergMarquardtJL(:cholesky), FastLevenbergMarquardtJL(:qr)] for solver in solvers @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)