
Commit

Add a wrapper over LeastSquaresOptim
avik-pal committed Oct 14, 2023
1 parent a6af39c commit 0497150
Showing 8 changed files with 122 additions and 9 deletions.
9 changes: 8 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
-version = "2.2.1"
+version = "2.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -24,6 +24,12 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

+[weakdeps]
+LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
+
+[extensions]
+NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"

[compat]
ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
@@ -33,6 +39,7 @@ EnumX = "1"
Enzyme = "0.11"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
+LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearSolve = "2"
NonlinearProblemLibrary = "0.1"
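
With LeastSquaresOptim.jl declared as a weak dependency, the wrapper adds no load
cost unless the user opts in; loading both packages activates the extension
module. A minimal sketch of the intended usage, assuming Julia 1.9+ package
extensions:

using NonlinearSolve        # exports LSOptimSolver, but the wrapper stays inert
using LeastSquaresOptim     # triggers loading of NonlinearSolveLeastSquaresOptimExt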
64 changes: 64 additions & 0 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
@@ -0,0 +1,64 @@
module NonlinearSolveLeastSquaresOptimExt

using NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import LeastSquaresOptim as LSO

NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true

function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
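# Map the `alg`/`linsolve` Symbols carried in LSOptimSolver's type parameters
# to LeastSquaresOptim's concrete optimizer and linear solver types.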
ls = linsolve == :qr ? LSO.QR() :
(linsolve == :cholesky ? LSO.Cholesky() :
(linsolve == :lsmr ? LSO.LSMR() : nothing))
if alg == :lm
return LSO.LevenbergMarquardt(ls)
elseif alg == :dogleg
return LSO.Dogleg(ls)
else
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg"))
end
end

@concrete struct LeastSquaresOptimCache
prob
alg
allocated_prob
kwargs
end

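# Adapt NonlinearSolve's residual signatures -- f(du, u, p) in-place and f(u, p)
# out-of-place -- to the two-argument f!(du, u) form LeastSquaresOptim expects.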
@concrete struct FunctionWrapper{iip}
f
p
end

(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p)
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...;
abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

f! = FunctionWrapper{iip}(prob.f, prob.p)
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)

lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = prob.f.resid_prototype, g!,
J = prob.f.jac_prototype, alg.autodiff,
output_length = length(prob.f.resid_prototype))
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))

return LeastSquaresOptimCache(prob, alg, allocated_prob,
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose,
kwargs...))
end

function SciMLBase.solve!(cache::LeastSquaresOptimCache)
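# Run the optimizer on the pre-allocated problem and translate
# LeastSquaresOptim's result into a SciMLBase solution object.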
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...)
maxiters = cache.kwargs[:iterations]
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success :
          (res.iterations ≥ maxiters ? ReturnCode.MaxIters : ReturnCode.ConvergenceFailure)
stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations)
return SciMLBase.build_solution(cache.prob, cache.alg, res.minimizer, res.ssr / 2;
retcode, original = res, stats)
end

end
5 changes: 4 additions & 1 deletion src/NonlinearSolve.jl
@@ -30,6 +30,8 @@ abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm

abstract type AbstractNonlinearSolveCache{iip} end

+extension_loaded(::Val) = false

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
@@ -60,6 +62,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
end

include("utils.jl")
+include("algorithms.jl")
include("linesearch.jl")
include("raphson.jl")
include("trustRegion.jl")
@@ -92,7 +95,7 @@ end

export RadiusUpdateSchemes

-export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
+export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton, LSOptimSolver

export LineSearch
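
The `extension_loaded` hook above defaults to `false` for every `Val` tag, and the
extension defines the `Val{:LeastSquaresOptim}` method to return `true`, so the
package can check at runtime whether the wrapper is active. A small sketch of that
query (hypothetical REPL session; the hook is internal and unexported):

julia> NonlinearSolve.extension_loaded(Val(:LeastSquaresOptim))
false

julia> using LeastSquaresOptim

julia> NonlinearSolve.extension_loaded(Val(:LeastSquaresOptim))
true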

28 changes: 28 additions & 0 deletions src/algorithms.jl
@@ -0,0 +1,28 @@
# Define Algorithms extended via extensions
"""
    LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)

Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl)
for solving `NonlinearLeastSquaresProblem`.

## Arguments:

- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
  `nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
  on the Jacobian structure.
- `autodiff`: Jacobian computation scheme passed to `LeastSquaresOptim.jl` when no
  analytic Jacobian is given. Can be `:central` or `:forward`.

!!! note

    This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
autodiff::Symbol

function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
@assert alg in (:lm, :dogleg)
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
@assert autodiff in (:central, :forward)

return new{alg, linsolve}(autodiff)
end
end
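
For orientation, a minimal end-to-end sketch of how the new wrapper is meant to be
driven. The residual function, data, and tolerances are illustrative, not taken
from the commit; note that `resid_prototype` is required, since the extension uses
it to size LeastSquaresOptim's output:

using NonlinearSolve, LeastSquaresOptim

# Illustrative out-of-place residual with parameters p.
loss(u, p) = u .^ 2 .- p
prob = NonlinearLeastSquaresProblem(
    NonlinearFunction(loss; resid_prototype = zeros(2)), [0.5, 0.5], [2.0, 3.0])

sol = solve(prob, LSOptimSolver(:lm; autodiff = :central); maxiters = 100)
sol.retcode  # ReturnCode.Success once x/f/g convergence is reached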
4 changes: 2 additions & 2 deletions src/gaussnewton.jl
@@ -93,8 +93,8 @@ end
function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
jacobian!!(J, cache)
-mul!(JᵀJ, J', J)
-mul!(Jᵀf, J', fu1)
+__matmul!(JᵀJ, J', J)
+__matmul!(Jᵀf, J', fu1)

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
8 changes: 5 additions & 3 deletions src/jacobian.jl
@@ -65,7 +65,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
-if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
+if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
sd = sparsity_detection_alg(f, alg.ad)
ad = alg.ad
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
@@ -74,7 +74,9 @@
jac_cache = nothing
end

-J = if !(linsolve_needs_jac || alg_wants_jac)
+# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
+# a reverse diff operation with the seed being `Jx`, this is not yet implemented
+J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
else
@@ -114,7 +116,7 @@ __get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
__get_nonsparse_ad(ad) = ad

__init_JᵀJ(J::Number) = zero(J)
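# Building JᵀJ as J' * J gives it the storage type of J (e.g. a sparse pattern)
# instead of an unconditionally dense zero buffer.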
-__init_JᵀJ(J::AbstractArray) = zeros(eltype(J), size(J, 2), size(J, 2))
+__init_JᵀJ(J::AbstractArray) = J' * J
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)

## Special Handling for Scalars
5 changes: 3 additions & 2 deletions src/levenberg.jl
@@ -192,7 +192,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})

if make_new_J
jacobian!!(cache.J, cache)
-mul!(cache.JᵀJ, cache.J', cache.J)
+__matmul!(cache.JᵀJ, cache.J', cache.J)
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
cache.make_new_J = false
cache.stats.njacs += 1
@@ -216,7 +216,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
mul!(cache.Jv, J, v)
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
mul!(cache.u_tmp, J', cache.fu_tmp)
-linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
+# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
+linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.a = -cache.du
8 changes: 8 additions & 0 deletions src/utils.jl
@@ -163,3 +163,11 @@ function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
return f(u, p)
end
end

"""
    __matmul!(C, A, B)

Defaults to `mul!(C, A, B)`. However, for sparse matrices uses `C .= A * B`.
"""
__matmul!(C, A, B) = mul!(C, A, B)
__matmul!(C::AbstractSparseMatrix, A, B) = C .= A * B
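
For context, a quick sketch of the dispatch in action, assuming SparseArrays is
loaded; the broadcast fallback sidesteps in-place `mul!` kernels that are missing
or slow for sparse destinations:

using LinearAlgebra, SparseArrays

J = sprand(10, 4, 0.3)
JᵀJ = J' * J                   # sparse destination with the product's pattern
__matmul!(JᵀJ, J', J)          # hits the sparse method: JᵀJ .= J' * J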
