-
-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a wrapper over LeastSquaresOptim
- Loading branch information
Showing
8 changed files
with
122 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
module NonlinearSolveLeastSquaresOptimExt | ||
|
||
using NonlinearSolve, SciMLBase | ||
import ConcreteStructs: @concrete | ||
import LeastSquaresOptim as LSO | ||
|
||
extension_loaded(::Val{:LeastSquaresOptim}) = true | ||
|
||
function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve} | ||
ls = linsolve == :qr ? LSO.QR() : | ||
(linsolve == :cholesky ? LSO.Cholesky() : | ||
(linsolve == :lsmr ? LSO.LSMR() : nothing)) | ||
if alg == :lm | ||
return LSO.LevenbergMarquardt(ls) | ||
elseif alg == :dogleg | ||
return LSO.Dogleg(ls) | ||
else | ||
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg")) | ||
end | ||
end | ||
|
||
@concrete struct LeastSquaresOptimCache | ||
prob | ||
alg | ||
allocated_prob | ||
kwargs | ||
end | ||
|
||
@concrete struct FunctionWrapper{iip} | ||
f | ||
p | ||
end | ||
|
||
(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p) | ||
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p)) | ||
|
||
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...; | ||
abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...) | ||
iip = SciMLBase.isinplace(prob) | ||
|
||
f! = FunctionWrapper{iip}(prob.f, prob.p) | ||
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p) | ||
|
||
lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = prob.f.resid_prototype, g!, | ||
J = prob.f.jac_prototype, alg.autodiff, | ||
output_length = length(prob.f.resid_prototype)) | ||
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg)) | ||
|
||
return LeastSquaresOptimCache(prob, alg, allocated_prob, | ||
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose, | ||
kwargs...)) | ||
end | ||
|
||
function SciMLBase.solve!(cache::LeastSquaresOptimCache) | ||
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...) | ||
maxiters = cache.kwargs[:iterations] | ||
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success : | ||
(res.iterations ≥ maxiters ? ReturnCode.MaxIters : ReturnCode.ConvergenceFailure) | ||
stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations) | ||
return SciMLBase.build_solution(cache.prob, cache.alg, res.minimizer, res.ssr / 2; | ||
retcode, original=res, stats) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Define Algorithms extended via extensions | ||
""" | ||
LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central) | ||
Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl) for solving | ||
`NonlinearLeastSquaresProblem`. | ||
## Arguments: | ||
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`. | ||
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If | ||
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based | ||
on the Jacobian structure. | ||
!!! note | ||
This algorithm is only available if `LeastSquaresOptim.jl` is installed. | ||
""" | ||
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm | ||
autodiff::Symbol | ||
|
||
function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central) | ||
@assert alg in (:lm, :dogleg) | ||
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr) | ||
@assert autodiff in (:central, :forward) | ||
|
||
return new{alg, linsolve}(autodiff) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters