-
-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
21 changed files
with
1,214 additions
and
1,072 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,84 @@ | ||
module NonlinearSolveFastLevenbergMarquardtExt | ||
|
||
using ArrayInterface: ArrayInterface | ||
using FastClosures: @closure | ||
|
||
using ArrayInterface: ArrayInterface | ||
using FastLevenbergMarquardt: FastLevenbergMarquardt | ||
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance | ||
using NonlinearSolve: NonlinearSolve, FastLevenbergMarquardtJL | ||
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode | ||
using StaticArraysCore: SArray | ||
|
||
const FastLM = FastLevenbergMarquardt | ||
using NonlinearSolveBase: NonlinearSolveBase | ||
using NonlinearSolve: NonlinearSolve, FastLevenbergMarquardtJL | ||
using SciMLBase: SciMLBase, AbstractNonlinearProblem, ReturnCode | ||
|
||
@inline function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve} | ||
if linsolve === :cholesky | ||
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x)) | ||
elseif linsolve === :qr | ||
return FastLM.QRSolver(eltype(x), length(x)) | ||
else | ||
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve")) | ||
end | ||
end | ||
@inline _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, ::SArray) where {linsolve} = linsolve | ||
const FastLM = FastLevenbergMarquardt | ||
|
||
function SciMLBase.__solve(prob::Union{NonlinearLeastSquaresProblem, NonlinearProblem}, | ||
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = nothing, | ||
reltol = nothing, maxiters = 1000, termination_condition = nothing, kwargs...) | ||
NonlinearSolve.__test_termination_condition( | ||
termination_condition, :FastLevenbergMarquardt) | ||
function SciMLBase.__solve( | ||
prob::AbstractNonlinearProblem, alg::FastLevenbergMarquardtJL, args...; | ||
alias_u0 = false, abstol = nothing, reltol = nothing, maxiters = 1000, | ||
termination_condition = nothing, kwargs... | ||
) | ||
NonlinearSolveBase.assert_extension_supported_termination_condition( | ||
termination_condition, alg | ||
) | ||
|
||
fn, u, resid = NonlinearSolve.__construct_extension_f( | ||
prob; alias_u0, can_handle_oop = Val(prob.u0 isa SArray)) | ||
f_wrapped, u, resid = NonlinearSolveBase.construct_extension_function_wrapper( | ||
prob; alias_u0, can_handle_oop = Val(prob.u0 isa SArray) | ||
) | ||
f = if prob.u0 isa SArray | ||
@closure (u, p) -> fn(u) | ||
@closure (u, p) -> f_wrapped(u) | ||
else | ||
@closure (du, u, p) -> fn(du, u) | ||
@closure (du, u, p) -> f_wrapped(du, u) | ||
end | ||
abstol = get_tolerance(abstol, eltype(u)) | ||
reltol = get_tolerance(reltol, eltype(u)) | ||
|
||
_jac_fn = NonlinearSolve.__construct_extension_jac( | ||
prob, alg, u, resid; alg.autodiff, can_handle_oop = Val(prob.u0 isa SArray)) | ||
abstol = NonlinearSolveBase.get_tolerance(abstol, eltype(u)) | ||
reltol = NonlinearSolveBase.get_tolerance(reltol, eltype(u)) | ||
|
||
jac_fn_wrapped = NonlinearSolveBase.construct_extension_jac( | ||
prob, alg, u, resid; alg.autodiff, can_handle_oop = Val(prob.u0 isa SArray) | ||
) | ||
jac_fn = if prob.u0 isa SArray | ||
@closure (u, p) -> _jac_fn(u) | ||
@closure (u, p) -> jac_fn_wrapped(u) | ||
else | ||
@closure (J, u, p) -> _jac_fn(J, u) | ||
@closure (J, u, p) -> jac_fn_wrapped(J, u) | ||
end | ||
|
||
solver_kwargs = (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, | ||
solver_kwargs = (; | ||
xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, | ||
alg.factor, alg.factoraccept, alg.factorreject, alg.minscale, | ||
alg.maxscale, alg.factorupdate, alg.minfactor, alg.maxfactor) | ||
alg.maxscale, alg.factorupdate, alg.minfactor, alg.maxfactor | ||
) | ||
|
||
if prob.u0 isa SArray | ||
res, fx, info, iter, nfev, njev = FastLM.lmsolve( | ||
f, jac_fn, prob.u0; solver_kwargs...) | ||
f, jac_fn, prob.u0; solver_kwargs... | ||
) | ||
LM, solver = nothing, nothing | ||
else | ||
J = prob.f.jac_prototype === nothing ? similar(u, length(resid), length(u)) : | ||
zero(prob.f.jac_prototype) | ||
solver = _fast_lm_solver(alg, u) | ||
|
||
solver = if alg.linsolve === :cholesky | ||
FastLM.CholeskySolver(ArrayInterface.undefmatrix(u)) | ||
elseif alg.linsolve === :qr | ||
FastLM.QRSolver(eltype(u), length(u)) | ||
else | ||
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: \ | ||
$(Meta.quot(alg.linsolve))")) | ||
end | ||
|
||
LM = FastLM.LMWorkspace(u, resid, J) | ||
|
||
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!( | ||
f, jac_fn, LM; solver, solver_kwargs...) | ||
f, jac_fn, LM; solver, solver_kwargs... | ||
) | ||
end | ||
|
||
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter) | ||
retcode = info == -1 ? ReturnCode.MaxIters : ReturnCode.Success | ||
return SciMLBase.build_solution(prob, alg, res, fx; retcode, | ||
original = (res, fx, info, iter, nfev, njev, LM, solver), stats) | ||
return SciMLBase.build_solution( | ||
prob, alg, res, fx; retcode, | ||
original = (res, fx, info, iter, nfev, njev, LM, solver), stats | ||
) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,51 @@ | ||
module NonlinearSolveFixedPointAccelerationExt | ||
|
||
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance | ||
using FixedPointAcceleration: FixedPointAcceleration, fixed_point | ||
|
||
using NonlinearSolveBase: NonlinearSolveBase | ||
using NonlinearSolve: NonlinearSolve, FixedPointAccelerationJL | ||
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode | ||
using FixedPointAcceleration: FixedPointAcceleration, fixed_point | ||
|
||
function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...; | ||
function SciMLBase.__solve( | ||
prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...; | ||
abstol = nothing, maxiters = 1000, alias_u0::Bool = false, | ||
show_trace::Val{PrintReports} = Val(false), | ||
termination_condition = nothing, kwargs...) where {PrintReports} | ||
NonlinearSolve.__test_termination_condition( | ||
termination_condition, :FixedPointAccelerationJL) | ||
show_trace::Val = Val(false), termination_condition = nothing, kwargs... | ||
) | ||
NonlinearSolveBase.assert_extension_supported_termination_condition( | ||
termination_condition, alg | ||
) | ||
|
||
f, u0, resid = NonlinearSolveBase.construct_extension_function_wrapper( | ||
prob; alias_u0, make_fixed_point = Val(true), force_oop = Val(true) | ||
) | ||
|
||
f, u0, resid = NonlinearSolve.__construct_extension_f( | ||
prob; alias_u0, make_fixed_point = Val(true), force_oop = Val(true)) | ||
tol = get_tolerance(abstol, eltype(u0)) | ||
tol = NonlinearSolveBase.get_tolerance(abstol, eltype(u0)) | ||
|
||
sol = fixed_point(f, u0; Algorithm = alg.algorithm, MaxIter = maxiters, MaxM = alg.m, | ||
sol = fixed_point( | ||
f, u0; Algorithm = alg.algorithm, MaxIter = maxiters, MaxM = alg.m, | ||
ConvergenceMetricThreshold = tol, ExtrapolationPeriod = alg.extrapolation_period, | ||
Dampening = alg.dampening, PrintReports, ReplaceInvalids = alg.replace_invalids, | ||
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true) | ||
Dampening = alg.dampening, PrintReports = show_trace isa Val{true}, | ||
ReplaceInvalids = alg.replace_invalids, | ||
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true | ||
) | ||
|
||
if sol.FixedPoint_ === missing | ||
u0 = prob.u0 isa Number ? u0[1] : u0 | ||
resid = NonlinearSolve.evaluate_f(prob, u0) | ||
resid = NonlinearSolveBase.Utils.evaluate_f(prob, u0) | ||
res = u0 | ||
converged = false | ||
else | ||
res = prob.u0 isa Number ? first(sol.FixedPoint_) : | ||
reshape(sol.FixedPoint_, size(prob.u0)) | ||
resid = NonlinearSolve.evaluate_f(prob, res) | ||
resid = NonlinearSolveBase.Utils.evaluate_f(prob, res) | ||
converged = maximum(abs, resid) ≤ tol | ||
end | ||
|
||
return SciMLBase.build_solution(prob, alg, res, resid; original = sol, | ||
return SciMLBase.build_solution( | ||
prob, alg, res, resid; original = sol, | ||
retcode = converged ? ReturnCode.Success : ReturnCode.Failure, | ||
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_)) | ||
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_) | ||
) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,79 +1,70 @@ | ||
module NonlinearSolveLeastSquaresOptimExt | ||
|
||
using ConcreteStructs: @concrete | ||
using LeastSquaresOptim: LeastSquaresOptim | ||
using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal, get_tolerance | ||
|
||
using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal | ||
using NonlinearSolve: NonlinearSolve, LeastSquaresOptimJL | ||
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode | ||
using SciMLBase: SciMLBase, AbstractNonlinearProblem, ReturnCode | ||
|
||
const LSO = LeastSquaresOptim | ||
|
||
@inline function _lso_solver(::LeastSquaresOptimJL{alg, ls}) where {alg, ls} | ||
linsolve = ls === :qr ? LSO.QR() : | ||
(ls === :cholesky ? LSO.Cholesky() : (ls === :lsmr ? LSO.LSMR() : nothing)) | ||
if alg === :lm | ||
return LSO.LevenbergMarquardt(linsolve) | ||
elseif alg === :dogleg | ||
return LSO.Dogleg(linsolve) | ||
else | ||
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg")) | ||
end | ||
end | ||
|
||
@concrete struct LeastSquaresOptimJLCache | ||
prob | ||
alg | ||
allocated_prob | ||
kwargs | ||
end | ||
|
||
function Base.show(io::IO, cache::LeastSquaresOptimJLCache) | ||
print(io, "LeastSquaresOptimJLCache()") | ||
end | ||
|
||
function SciMLBase.reinit!(cache::LeastSquaresOptimJLCache, args...; kwargs...) | ||
error("Reinitialization not supported for LeastSquaresOptimJL.") | ||
end | ||
|
||
function SciMLBase.__init(prob::Union{NonlinearLeastSquaresProblem, NonlinearProblem}, | ||
alg::LeastSquaresOptimJL, args...; alias_u0 = false, abstol = nothing, | ||
show_trace::Val{ShT} = Val(false), trace_level = TraceMinimal(), | ||
reltol = nothing, store_trace::Val{StT} = Val(false), maxiters = 1000, | ||
termination_condition = nothing, kwargs...) where {ShT, StT} | ||
NonlinearSolve.__test_termination_condition(termination_condition, :LeastSquaresOptim) | ||
function SciMLBase.__solve( | ||
prob::AbstractNonlinearProblem, alg::LeastSquaresOptimJL, args...; | ||
alias_u0 = false, abstol = nothing, reltol = nothing, maxiters = 1000, | ||
trace_level = TraceMinimal(), termination_condition = nothing, | ||
show_trace::Val = Val(false), store_trace::Val = Val(false), kwargs... | ||
) | ||
NonlinearSolveBase.assert_extension_supported_termination_condition( | ||
termination_condition, alg | ||
) | ||
|
||
f!, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0) | ||
abstol = get_tolerance(abstol, eltype(u)) | ||
reltol = get_tolerance(reltol, eltype(u)) | ||
f!, u, resid = NonlinearSolveBase.construct_extension_function_wrapper(prob; alias_u0) | ||
abstol = NonlinearSolveBase.get_tolerance(abstol, eltype(u)) | ||
reltol = NonlinearSolveBase.get_tolerance(reltol, eltype(u)) | ||
|
||
if prob.f.jac === nothing && alg.autodiff isa Symbol | ||
lsoprob = LSO.LeastSquaresProblem(; x = u, f!, y = resid, alg.autodiff, | ||
J = prob.f.jac_prototype, output_length = length(resid)) | ||
lsoprob = LSO.LeastSquaresProblem(; | ||
x = u, f!, y = resid, alg.autodiff, J = prob.f.jac_prototype, | ||
output_length = length(resid) | ||
) | ||
else | ||
g! = NonlinearSolve.__construct_extension_jac(prob, alg, u, resid; alg.autodiff) | ||
g! = NonlinearSolveBase.construct_extension_jac(prob, alg, u, resid; alg.autodiff) | ||
lsoprob = LSO.LeastSquaresProblem(; | ||
x = u, f!, y = resid, g!, J = prob.f.jac_prototype, | ||
output_length = length(resid)) | ||
output_length = length(resid) | ||
) | ||
end | ||
|
||
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg)) | ||
linsolve = alg.ls === :qr ? LSO.QR() : | ||
(alg.ls === :cholesky ? LSO.Cholesky() : | ||
(alg.ls === :lsmr ? LSO.LSMR() : nothing)) | ||
|
||
return LeastSquaresOptimJLCache(prob, | ||
alg, | ||
allocated_prob, | ||
(; x_tol = reltol, f_tol = abstol, g_tol = abstol, iterations = maxiters, | ||
show_trace = ShT, store_trace = StT, show_every = trace_level.print_frequency)) | ||
end | ||
lso_solver = if alg.alg === :lm | ||
LSO.LevenbergMarquardt(linsolve) | ||
elseif alg.alg === :dogleg | ||
LSO.Dogleg(linsolve) | ||
else | ||
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $(Meta.quot(alg.alg))")) | ||
end | ||
|
||
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, lso_solver(alg)) | ||
res = LSO.optimize!( | ||
allocated_prob; | ||
x_tol = reltol, f_tol = abstol, g_tol = abstol, iterations = maxiters, | ||
show_trace = show_trace isa Val{true}, store_trace = store_trace isa Val{true}, | ||
show_every = trace_level.print_frequency | ||
) | ||
|
||
function SciMLBase.solve!(cache::LeastSquaresOptimJLCache) | ||
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...) | ||
maxiters = cache.kwargs[:iterations] | ||
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success : | ||
(res.iterations ≥ maxiters ? ReturnCode.MaxIters : | ||
ReturnCode.ConvergenceFailure) | ||
stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations) | ||
|
||
f!(resid, res.minimizer) | ||
|
||
return SciMLBase.build_solution( | ||
cache.prob, cache.alg, res.minimizer, res.ssr / 2; retcode, original = res, stats) | ||
prob, alg, res.minimizer, resid; retcode, original = res, stats | ||
) | ||
end | ||
|
||
end |
Oops, something went wrong.