Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Impoving NLS Solvers #236

Merged
merged 5 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.2.1"
version = "2.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -24,15 +24,25 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"

[extensions]
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"

[compat]
ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
ConcreteStructs = "0.2"
DiffEqBase = "6.130"
EnumX = "1"
Enzyme = "0.11"
FastLevenbergMarquardt = "0.1"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearSolve = "2"
NonlinearProblemLibrary = "0.1"
Expand All @@ -50,7 +60,9 @@ julia = "1.9"
[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
Expand All @@ -64,4 +76,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt"]
71 changes: 71 additions & 0 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
module NonlinearSolveFastLevenbergMarquardtExt

using ArrayInterface, NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import FastLevenbergMarquardt as FastLM

NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true

function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve}
if linsolve == :cholesky
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
elseif linsolve == :qr
return FastLM.QRSolver(eltype(x), length(x))
else
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))

Check warning on line 15 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L15

Added line #L15 was not covered by tests
end
end

@concrete struct FastLMCache
f!
J!
prob
alg
lmworkspace
solver
kwargs
end

@concrete struct InplaceFunction{iip} <: Function
f
end

(f::InplaceFunction{true})(fx, x, p) = f.f(fx, x, p)
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))

Check warning on line 34 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L34

Added line #L34 was not covered by tests

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8,
verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

@assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!"

f! = InplaceFunction{iip}(prob.f)
J! = InplaceFunction{iip}(prob.f.jac)

resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
prob.f.resid_prototype

J = similar(prob.u0, length(resid_prototype), length(prob.u0))

solver = _fast_lm_solver(alg, prob.u0)
LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J)

return FastLMCache(f!, J!, prob, alg, LM, solver,
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
alg.maxfactor, kwargs...))
end

function SciMLBase.solve!(cache::FastLMCache)
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
retcode = info == 1 ? ReturnCode.Success :
(info == -1 ? ReturnCode.MaxIters : ReturnCode.Default)
return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
end

end
68 changes: 68 additions & 0 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
module NonlinearSolveLeastSquaresOptimExt

using NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import LeastSquaresOptim as LSO

NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true

function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
ls = linsolve == :qr ? LSO.QR() :
(linsolve == :cholesky ? LSO.Cholesky() :
(linsolve == :lsmr ? LSO.LSMR() : nothing))
if alg == :lm
return LSO.LevenbergMarquardt(ls)
elseif alg == :dogleg
return LSO.Dogleg(ls)
else
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg"))

Check warning on line 18 in ext/NonlinearSolveLeastSquaresOptimExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveLeastSquaresOptimExt.jl#L18

Added line #L18 was not covered by tests
end
end

@concrete struct LeastSquaresOptimCache
prob
alg
allocated_prob
kwargs
end

@concrete struct FunctionWrapper{iip}
f
p
end

(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p)
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...;
abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

f! = FunctionWrapper{iip}(prob.f, prob.p)
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)

resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
prob.f.resid_prototype

lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = resid_prototype, g!,
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))

return LeastSquaresOptimCache(prob, alg, allocated_prob,
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose,
kwargs...))
end

function SciMLBase.solve!(cache::LeastSquaresOptimCache)
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...)
maxiters = cache.kwargs[:iterations]
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success :
(res.iterations ≥ maxiters ? ReturnCode.MaxIters :
ReturnCode.ConvergenceFailure)
stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations)
return SciMLBase.build_solution(cache.prob, cache.alg, res.minimizer, res.ssr / 2;
retcode, original = res, stats)
end

end
4 changes: 4 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

abstract type AbstractNonlinearSolveCache{iip} end

extension_loaded(::Val) = false

Check warning on line 33 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L33

Added line #L33 was not covered by tests

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
Expand Down Expand Up @@ -60,6 +62,7 @@
end

include("utils.jl")
include("algorithms.jl")
include("linesearch.jl")
include("raphson.jl")
include("trustRegion.jl")
Expand Down Expand Up @@ -93,6 +96,7 @@
export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
export LSOptimSolver, FastLevenbergMarquardtSolver

export LineSearch

Expand Down
80 changes: 80 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Define Algorithms extended via extensions
"""
LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)

Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl) for solving
`NonlinearLeastSquaresProblem`.

## Arguments:

- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
on the Jacobian structure.
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.

!!! note
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
autodiff::Symbol
end

function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
@assert alg in (:lm, :dogleg)
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
@assert autodiff in (:central, :forward)

if !extension_loaded(Val(:LeastSquaresOptim))
@warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \

Check warning on line 29 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L29

Added line #L29 was not covered by tests
before `solve(prob, LSOptimSolver())` is called."
end

return LSOptimSolver{alg, linsolve}(autodiff)
end

"""
FastLevenbergMarquardtSolver(linsolve = :cholesky)

Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
`NonlinearLeastSquaresProblem`.

!!! warning
This is not really the fastest solver. It is called that since the original package
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.

!!! warning
This algorithm requires the jacobian function to be provided!

## Arguments:

- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.

!!! note
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
"""
@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm
factor
factoraccept
factorreject
factorupdate::Symbol
minscale
maxscale
minfactor
maxfactor
end

function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6,
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
@assert linsolve in (:qr, :cholesky)
@assert factorupdate in (:marquardt, :nielson)

if !extension_loaded(Val(:FastLevenbergMarquardt))
@warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \

Check warning on line 74 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L74

Added line #L74 was not covered by tests
before `solve(prob, FastLevenbergMarquardtSolver())` is called."
end

return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject,
factorupdate, minscale, maxscale, minfactor, maxfactor)
end
18 changes: 9 additions & 9 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
adkwargs...)
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)

An advanced GaussNewton implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
Expand Down Expand Up @@ -41,7 +41,7 @@ for large-scale and numerically-difficult nonlinear least squares problems.
precs
end

function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
Expand Down Expand Up @@ -93,12 +93,12 @@ end
function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
jacobian!!(J, cache)
mul!(JᵀJ, J', J)
mul!(Jᵀf, J', fu1)
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
linu = _vec(du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(cache.fu_new, u, p)
Expand All @@ -125,8 +125,8 @@ function perform_step!(cache::GaussNewtonCache{false})
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
linu = _vec(cache.du), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
Expand Down
22 changes: 15 additions & 7 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
sd = sparsity_detection_alg(f, alg.ad)
ad = alg.ad
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
Expand All @@ -74,7 +74,9 @@
jac_cache = nothing
end

J = if !(linsolve_needs_jac || alg_wants_jac)
# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
else
Expand All @@ -93,14 +95,14 @@
Jᵀfu = J' * fu
end

linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu);
u0 = _vec(du))
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))

weight = similar(u)
recursivefill!(weight, true)

Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
Pl, Pr = wrapprecs(alg.precs(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J, nothing, u, p,
nothing, nothing, nothing, nothing, nothing)..., weight)
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)

Expand All @@ -114,9 +116,15 @@
__get_nonsparse_ad(ad) = ad

__init_JᵀJ(J::Number) = zero(J)
__init_JᵀJ(J::AbstractArray) = zeros(eltype(J), size(J, 2), size(J, 2))
__init_JᵀJ(J::AbstractArray) = J' * J
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)

__maybe_symmetric(x) = Symmetric(x)
__maybe_symmetric(x::Number) = x

Check warning on line 123 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L123

Added line #L123 was not covered by tests
# LinearSolve with `nothing` doesn't dispatch correctly here
__maybe_symmetric(x::StaticArray) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x

Check warning on line 126 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L126

Added line #L126 was not covered by tests

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
Expand Down
Loading
Loading