Skip to content

Commit

Permalink
refactor: cleanup all wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2024
1 parent 4f74479 commit 693949b
Show file tree
Hide file tree
Showing 21 changed files with 1,214 additions and 1,072 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down Expand Up @@ -62,10 +63,12 @@ Aqua = "0.8"
ArrayInterface = "7.16"
BandedMatrices = "1.5"
BenchmarkTools = "1.4"
BracketingNonlinearSolve = "1"
CUDA = "5.5"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.155.3"
DifferentiationInterface = "0.6.18"
Enzyme = "0.13.11"
ExplicitImports = "1.5"
FastClosures = "0.3.2"
Expand All @@ -87,6 +90,9 @@ NLsolve = "4.5"
NaNMath = "1"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1"
NonlinearSolveFirstOrder = "1"
NonlinearSolveQuasiNewton = "1"
NonlinearSolveSpectralMethods = "1"
OrdinaryDiffEqTsit5 = "1.1.0"
PETSc = "0.2"
Pkg = "1.10"
Expand Down
1 change: 1 addition & 0 deletions common/common_nlls_testing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ prob_iip_vjp = NonlinearLeastSquaresProblem(
)

export prob_oop, prob_iip, prob_oop_vjp, prob_iip_vjp
export true_function, θ_true, x, y_target, loss_function, θ_init
86 changes: 49 additions & 37 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
@@ -1,72 +1,84 @@
module NonlinearSolveFastLevenbergMarquardtExt

using ArrayInterface: ArrayInterface
using FastClosures: @closure

using ArrayInterface: ArrayInterface
using FastLevenbergMarquardt: FastLevenbergMarquardt
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
using NonlinearSolve: NonlinearSolve, FastLevenbergMarquardtJL
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode
using StaticArraysCore: SArray

const FastLM = FastLevenbergMarquardt
using NonlinearSolveBase: NonlinearSolveBase
using NonlinearSolve: NonlinearSolve, FastLevenbergMarquardtJL
using SciMLBase: SciMLBase, AbstractNonlinearProblem, ReturnCode

@inline function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
if linsolve === :cholesky
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
elseif linsolve === :qr
return FastLM.QRSolver(eltype(x), length(x))
else
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
end
end
@inline _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, ::SArray) where {linsolve} = linsolve
const FastLM = FastLevenbergMarquardt

function SciMLBase.__solve(prob::Union{NonlinearLeastSquaresProblem, NonlinearProblem},
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = nothing,
reltol = nothing, maxiters = 1000, termination_condition = nothing, kwargs...)
NonlinearSolve.__test_termination_condition(
termination_condition, :FastLevenbergMarquardt)
function SciMLBase.__solve(
prob::AbstractNonlinearProblem, alg::FastLevenbergMarquardtJL, args...;
alias_u0 = false, abstol = nothing, reltol = nothing, maxiters = 1000,
termination_condition = nothing, kwargs...
)
NonlinearSolveBase.assert_extension_supported_termination_condition(
termination_condition, alg
)

fn, u, resid = NonlinearSolve.__construct_extension_f(
prob; alias_u0, can_handle_oop = Val(prob.u0 isa SArray))
f_wrapped, u, resid = NonlinearSolveBase.construct_extension_function_wrapper(
prob; alias_u0, can_handle_oop = Val(prob.u0 isa SArray)
)
f = if prob.u0 isa SArray
@closure (u, p) -> fn(u)
@closure (u, p) -> f_wrapped(u)
else
@closure (du, u, p) -> fn(du, u)
@closure (du, u, p) -> f_wrapped(du, u)
end
abstol = get_tolerance(abstol, eltype(u))
reltol = get_tolerance(reltol, eltype(u))

_jac_fn = NonlinearSolve.__construct_extension_jac(
prob, alg, u, resid; alg.autodiff, can_handle_oop = Val(prob.u0 isa SArray))
abstol = NonlinearSolveBase.get_tolerance(abstol, eltype(u))
reltol = NonlinearSolveBase.get_tolerance(reltol, eltype(u))

jac_fn_wrapped = NonlinearSolveBase.construct_extension_jac(
prob, alg, u, resid; alg.autodiff, can_handle_oop = Val(prob.u0 isa SArray)
)
jac_fn = if prob.u0 isa SArray
@closure (u, p) -> _jac_fn(u)
@closure (u, p) -> jac_fn_wrapped(u)
else
@closure (J, u, p) -> _jac_fn(J, u)
@closure (J, u, p) -> jac_fn_wrapped(J, u)
end

solver_kwargs = (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters,
solver_kwargs = (;
xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters,
alg.factor, alg.factoraccept, alg.factorreject, alg.minscale,
alg.maxscale, alg.factorupdate, alg.minfactor, alg.maxfactor)
alg.maxscale, alg.factorupdate, alg.minfactor, alg.maxfactor
)

if prob.u0 isa SArray
res, fx, info, iter, nfev, njev = FastLM.lmsolve(
f, jac_fn, prob.u0; solver_kwargs...)
f, jac_fn, prob.u0; solver_kwargs...
)
LM, solver = nothing, nothing
else
J = prob.f.jac_prototype === nothing ? similar(u, length(resid), length(u)) :
zero(prob.f.jac_prototype)
solver = _fast_lm_solver(alg, u)

solver = if alg.linsolve === :cholesky
FastLM.CholeskySolver(ArrayInterface.undefmatrix(u))
elseif alg.linsolve === :qr
FastLM.QRSolver(eltype(u), length(u))
else
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: \
$(Meta.quot(alg.linsolve))"))
end

LM = FastLM.LMWorkspace(u, resid, J)

res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(
f, jac_fn, LM; solver, solver_kwargs...)
f, jac_fn, LM; solver, solver_kwargs...
)
end

stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
retcode = info == -1 ? ReturnCode.MaxIters : ReturnCode.Success
return SciMLBase.build_solution(prob, alg, res, fx; retcode,
original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
return SciMLBase.build_solution(
prob, alg, res, fx; retcode,
original = (res, fx, info, iter, nfev, njev, LM, solver), stats
)
end

end
44 changes: 27 additions & 17 deletions ext/NonlinearSolveFixedPointAccelerationExt.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,51 @@
module NonlinearSolveFixedPointAccelerationExt

using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
using FixedPointAcceleration: FixedPointAcceleration, fixed_point

using NonlinearSolveBase: NonlinearSolveBase
using NonlinearSolve: NonlinearSolve, FixedPointAccelerationJL
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
using FixedPointAcceleration: FixedPointAcceleration, fixed_point

function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...;
function SciMLBase.__solve(
prob::NonlinearProblem, alg::FixedPointAccelerationJL, args...;
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
show_trace::Val{PrintReports} = Val(false),
termination_condition = nothing, kwargs...) where {PrintReports}
NonlinearSolve.__test_termination_condition(
termination_condition, :FixedPointAccelerationJL)
show_trace::Val = Val(false), termination_condition = nothing, kwargs...
)
NonlinearSolveBase.assert_extension_supported_termination_condition(
termination_condition, alg
)

f, u0, resid = NonlinearSolveBase.construct_extension_function_wrapper(
prob; alias_u0, make_fixed_point = Val(true), force_oop = Val(true)
)

f, u0, resid = NonlinearSolve.__construct_extension_f(
prob; alias_u0, make_fixed_point = Val(true), force_oop = Val(true))
tol = get_tolerance(abstol, eltype(u0))
tol = NonlinearSolveBase.get_tolerance(abstol, eltype(u0))

sol = fixed_point(f, u0; Algorithm = alg.algorithm, MaxIter = maxiters, MaxM = alg.m,
sol = fixed_point(
f, u0; Algorithm = alg.algorithm, MaxIter = maxiters, MaxM = alg.m,
ConvergenceMetricThreshold = tol, ExtrapolationPeriod = alg.extrapolation_period,
Dampening = alg.dampening, PrintReports, ReplaceInvalids = alg.replace_invalids,
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true)
Dampening = alg.dampening, PrintReports = show_trace isa Val{true},
ReplaceInvalids = alg.replace_invalids,
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true
)

if sol.FixedPoint_ === missing
u0 = prob.u0 isa Number ? u0[1] : u0
resid = NonlinearSolve.evaluate_f(prob, u0)
resid = NonlinearSolveBase.Utils.evaluate_f(prob, u0)
res = u0
converged = false
else
res = prob.u0 isa Number ? first(sol.FixedPoint_) :
reshape(sol.FixedPoint_, size(prob.u0))
resid = NonlinearSolve.evaluate_f(prob, res)
resid = NonlinearSolveBase.Utils.evaluate_f(prob, res)
converged = maximum(abs, resid) tol
end

return SciMLBase.build_solution(prob, alg, res, resid; original = sol,
return SciMLBase.build_solution(
prob, alg, res, resid; original = sol,
retcode = converged ? ReturnCode.Success : ReturnCode.Failure,
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_))
stats = SciMLBase.NLStats(sol.Iterations_, 0, 0, 0, sol.Iterations_)
)
end

end
99 changes: 45 additions & 54 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
@@ -1,79 +1,70 @@
module NonlinearSolveLeastSquaresOptimExt

using ConcreteStructs: @concrete
using LeastSquaresOptim: LeastSquaresOptim
using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal, get_tolerance

using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal
using NonlinearSolve: NonlinearSolve, LeastSquaresOptimJL
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode
using SciMLBase: SciMLBase, AbstractNonlinearProblem, ReturnCode

const LSO = LeastSquaresOptim

@inline function _lso_solver(::LeastSquaresOptimJL{alg, ls}) where {alg, ls}
linsolve = ls === :qr ? LSO.QR() :
(ls === :cholesky ? LSO.Cholesky() : (ls === :lsmr ? LSO.LSMR() : nothing))
if alg === :lm
return LSO.LevenbergMarquardt(linsolve)
elseif alg === :dogleg
return LSO.Dogleg(linsolve)
else
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg"))
end
end

@concrete struct LeastSquaresOptimJLCache
prob
alg
allocated_prob
kwargs
end

function Base.show(io::IO, cache::LeastSquaresOptimJLCache)
print(io, "LeastSquaresOptimJLCache()")
end

function SciMLBase.reinit!(cache::LeastSquaresOptimJLCache, args...; kwargs...)
error("Reinitialization not supported for LeastSquaresOptimJL.")
end

function SciMLBase.__init(prob::Union{NonlinearLeastSquaresProblem, NonlinearProblem},
alg::LeastSquaresOptimJL, args...; alias_u0 = false, abstol = nothing,
show_trace::Val{ShT} = Val(false), trace_level = TraceMinimal(),
reltol = nothing, store_trace::Val{StT} = Val(false), maxiters = 1000,
termination_condition = nothing, kwargs...) where {ShT, StT}
NonlinearSolve.__test_termination_condition(termination_condition, :LeastSquaresOptim)
function SciMLBase.__solve(
prob::AbstractNonlinearProblem, alg::LeastSquaresOptimJL, args...;
alias_u0 = false, abstol = nothing, reltol = nothing, maxiters = 1000,
trace_level = TraceMinimal(), termination_condition = nothing,
show_trace::Val = Val(false), store_trace::Val = Val(false), kwargs...
)
NonlinearSolveBase.assert_extension_supported_termination_condition(
termination_condition, alg
)

f!, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
abstol = get_tolerance(abstol, eltype(u))
reltol = get_tolerance(reltol, eltype(u))
f!, u, resid = NonlinearSolveBase.construct_extension_function_wrapper(prob; alias_u0)
abstol = NonlinearSolveBase.get_tolerance(abstol, eltype(u))
reltol = NonlinearSolveBase.get_tolerance(reltol, eltype(u))

if prob.f.jac === nothing && alg.autodiff isa Symbol
lsoprob = LSO.LeastSquaresProblem(; x = u, f!, y = resid, alg.autodiff,
J = prob.f.jac_prototype, output_length = length(resid))
lsoprob = LSO.LeastSquaresProblem(;
x = u, f!, y = resid, alg.autodiff, J = prob.f.jac_prototype,
output_length = length(resid)
)
else
g! = NonlinearSolve.__construct_extension_jac(prob, alg, u, resid; alg.autodiff)
g! = NonlinearSolveBase.construct_extension_jac(prob, alg, u, resid; alg.autodiff)
lsoprob = LSO.LeastSquaresProblem(;
x = u, f!, y = resid, g!, J = prob.f.jac_prototype,
output_length = length(resid))
output_length = length(resid)
)
end

allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))
linsolve = alg.ls === :qr ? LSO.QR() :
(alg.ls === :cholesky ? LSO.Cholesky() :
(alg.ls === :lsmr ? LSO.LSMR() : nothing))

return LeastSquaresOptimJLCache(prob,
alg,
allocated_prob,
(; x_tol = reltol, f_tol = abstol, g_tol = abstol, iterations = maxiters,
show_trace = ShT, store_trace = StT, show_every = trace_level.print_frequency))
end
lso_solver = if alg.alg === :lm
LSO.LevenbergMarquardt(linsolve)
elseif alg.alg === :dogleg
LSO.Dogleg(linsolve)
else
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $(Meta.quot(alg.alg))"))
end

allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, lso_solver(alg))
res = LSO.optimize!(
allocated_prob;
x_tol = reltol, f_tol = abstol, g_tol = abstol, iterations = maxiters,
show_trace = show_trace isa Val{true}, store_trace = store_trace isa Val{true},
show_every = trace_level.print_frequency
)

function SciMLBase.solve!(cache::LeastSquaresOptimJLCache)
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...)
maxiters = cache.kwargs[:iterations]
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success :
(res.iterations maxiters ? ReturnCode.MaxIters :
ReturnCode.ConvergenceFailure)
stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations)

f!(resid, res.minimizer)

return SciMLBase.build_solution(
cache.prob, cache.alg, res.minimizer, res.ssr / 2; retcode, original = res, stats)
prob, alg, res.minimizer, resid; retcode, original = res, stats
)
end

end
Loading

0 comments on commit 693949b

Please sign in to comment.