Skip to content

Commit

Permalink
Rename wrappers to be consistent with other SciML packages
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 16, 2023
1 parent 1a0e5ee commit 17693fb
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.3.0"
version = "2.4.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
12 changes: 5 additions & 7 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ using ArrayInterface, NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import FastLevenbergMarquardt as FastLM

NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true

function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve}
function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
if linsolve == :cholesky
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
elseif linsolve == :qr
Expand All @@ -16,7 +14,7 @@ function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {lin
end
end

@concrete struct FastLMCache
@concrete struct FastLevenbergMarquardtJLCache
f!
J!
prob
Expand All @@ -34,7 +32,7 @@ end
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8,
alg::FastLevenbergMarquardtJL, args...; abstol = 1e-8, reltol = 1e-8,
verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

Expand All @@ -52,13 +50,13 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
solver = _fast_lm_solver(alg, prob.u0)
LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J)

return FastLMCache(f!, J!, prob, alg, LM, solver,
return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver,
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
alg.maxfactor, kwargs...))
end

function SciMLBase.solve!(cache::FastLMCache)
function SciMLBase.solve!(cache::FastLevenbergMarquardtJLCache)
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
Expand Down
14 changes: 6 additions & 8 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ using NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import LeastSquaresOptim as LSO

NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true

function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve}
ls = linsolve == :qr ? LSO.QR() :
(linsolve == :cholesky ? LSO.Cholesky() :
(linsolve == :lsmr ? LSO.LSMR() : nothing))
Expand All @@ -19,7 +17,7 @@ function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
end
end

@concrete struct LeastSquaresOptimCache
@concrete struct LeastSquaresOptimJLCache
prob
alg
allocated_prob
Expand All @@ -34,8 +32,8 @@ end
(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p)
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...;
abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LeastSquaresOptimJL,
args...; abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

f! = FunctionWrapper{iip}(prob.f, prob.p)
Expand All @@ -49,12 +47,12 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))

return LeastSquaresOptimCache(prob, alg, allocated_prob,
return LeastSquaresOptimJLCache(prob, alg, allocated_prob,
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose,
kwargs...))
end

function SciMLBase.solve!(cache::LeastSquaresOptimCache)
function SciMLBase.solve!(cache::LeastSquaresOptimJLCache)
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...)
maxiters = cache.kwargs[:iterations]
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success :
Expand Down
4 changes: 1 addition & 3 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm

abstract type AbstractNonlinearSolveCache{iip} end

extension_loaded(::Val) = false

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
Expand Down Expand Up @@ -96,7 +94,7 @@ end
export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
export LSOptimSolver, FastLevenbergMarquardtSolver
export LeastSquaresOptimJL, FastLevenbergMarquardtJL

export LineSearch

Expand Down
33 changes: 16 additions & 17 deletions src/algorithms.jl → src/extension_algs.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Define Algorithms extended via extensions
# This file only include the algorithm struct to be exported by LinearSolve.jl. The main
# functionality is implemented as package extensions
"""
LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
LeastSquaresOptimJL(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl) for solving
`NonlinearLeastSquaresProblem`.
Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl)
for solving `NonlinearLeastSquaresProblem`.
## Arguments:
Expand All @@ -16,25 +17,24 @@ Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquare
!!! note
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
autodiff::Symbol
end

function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
function LeastSquaresOptimJL(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
@assert alg in (:lm, :dogleg)
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
@assert autodiff in (:central, :forward)

if !extension_loaded(Val(:LeastSquaresOptim))
@warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \
before `solve(prob, LSOptimSolver())` is called."
if Base.get_extension(@__MODULE__, :NonlinearSolveLeastSquaresOptimExt) === nothing
error("LeastSquaresOptimJL requires LeastSquaresOptim.jl to be loaded")
end

return LSOptimSolver{alg, linsolve}(autodiff)
return LeastSquaresOptimJL{alg, linsolve}(autodiff)
end

"""
FastLevenbergMarquardtSolver(linsolve = :cholesky)
FastLevenbergMarquardtJL(linsolve = :cholesky)
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
`NonlinearLeastSquaresProblem`.
Expand All @@ -53,7 +53,7 @@ Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenberg
!!! note
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
"""
@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm
factor
factoraccept
factorreject
Expand All @@ -64,17 +64,16 @@ Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenberg
maxfactor
end

function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6,
function FastLevenbergMarquardtJL(linsolve::Symbol = :cholesky; factor = 1e-6,
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
@assert linsolve in (:qr, :cholesky)
@assert factorupdate in (:marquardt, :nielson)

if !extension_loaded(Val(:FastLevenbergMarquardt))
@warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \
before `solve(prob, FastLevenbergMarquardtSolver())` is called."
if Base.get_extension(@__MODULE__, :NonlinearSolveFastLevenbergMarquardtExt) === nothing
error("LeastSquaresOptimJL requires FastLevenbergMarquardt.jl to be loaded")
end

return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject,
return FastLevenbergMarquardtJL{linsolve}(factor, factoraccept, factorreject,
factorupdate, minscale, maxscale, minfactor, maxfactor)
end
5 changes: 3 additions & 2 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

nlls_problems = [prob_oop, prob_iip]
solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)]
solvers = [GaussNewton(), LevenbergMarquardt(), LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg)]

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
Expand All @@ -43,7 +44,7 @@ end
prob = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target), jac = jac!), θ_init, x)

solvers = [FastLevenbergMarquardtSolver(:cholesky), FastLevenbergMarquardtSolver(:qr)]
solvers = [FastLevenbergMarquardtJL(:cholesky), FastLevenbergMarquardtJL(:qr)]

for solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
Expand Down

0 comments on commit 17693fb

Please sign in to comment.