Skip to content

Commit

Permalink
refactor: minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 28, 2024
1 parent 5af6010 commit 4e1098f
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 77 deletions.
22 changes: 22 additions & 0 deletions common/nlls_problem_workloads.jl
Original file line number Diff line number Diff line change
@@ -1 +1,23 @@
using SciMLBase: NonlinearLeastSquaresProblem, NonlinearFunction

nonlinear_functions = (
(NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
(
NonlinearFunction{true}(
(du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)
),
[0.1, 0.0]
),
(
NonlinearFunction{true}(
(du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), resid_prototype = zeros(4)
),
[0.1, 0.1]
)
)

nlls_problems = NonlinearLeastSquaresProblem[]
for (fn, u0) in nonlinear_functions
push!(nlls_problems, NonlinearLeastSquaresProblem(fn, u0, 2.0))
end
4 changes: 3 additions & 1 deletion lib/BracketingNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "BracketingNonlinearSolve"
uuid = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.0"
version = "1.1.0"

[deps]
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
Expand All @@ -25,6 +26,7 @@ ForwardDiff = "0.10.36"
InteractiveUtils = "<0.0.1, 1"
NonlinearSolveBase = "1"
PrecompileTools = "1.2"
Reexport = "1.2"
SciMLBase = "2.50"
Test = "1.10"
TestItemRunner = "1"
Expand Down
7 changes: 4 additions & 3 deletions lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module BracketingNonlinearSolve

using ConcreteStructs: @concrete
using Reexport: @reexport

using CommonSolve: CommonSolve, solve
using NonlinearSolveBase: NonlinearSolveBase
Expand Down Expand Up @@ -30,7 +31,8 @@ end
@setup_workload begin
for T in (Float32, Float64)
prob_brack = IntervalNonlinearProblem{false}(
(u, p) -> u^2 - p, T.((0.0, 2.0)), T(2))
(u, p) -> u^2 - p, T.((0.0, 2.0)), T(2)
)
algs = (Alefeld(), Bisection(), Brent(), Falsi(), ITP(), Ridder())

@compile_workload begin
Expand All @@ -41,8 +43,7 @@ end
end
end

export IntervalNonlinearProblem
export solve
@reexport using SciMLBase, NonlinearSolveBase

export Alefeld, Bisection, Brent, Falsi, ITP, Ridder

Expand Down
2 changes: 0 additions & 2 deletions lib/NonlinearSolveFirstOrder/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

Expand All @@ -46,7 +45,6 @@ PrecompileTools = "1.2"
ReTestItems = "1.24"
Reexport = "1"
SciMLBase = "2.54"
SciMLOperators = "0.3.11"
Setfield = "1.1.1"
StableRNGs = "1"
StaticArraysCore = "1.4.3"
Expand Down
43 changes: 31 additions & 12 deletions lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,20 @@ using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches
using FiniteDiff: FiniteDiff # Default Finite Difference Method
using ForwardDiff: ForwardDiff # Default Forward Mode AD
using LinearAlgebra: LinearAlgebra, Diagonal, dot, inv, diag
using LinearAlgebra: LinearAlgebra, Diagonal, dot
using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase
using MaybeInplace: @bb
using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
AbstractNonlinearSolveCache, AbstractResetCondition,
AbstractResetConditionCache, AbstractApproximateJacobianStructure,
AbstractJacobianCache, AbstractJacobianInitialization,
AbstractApproximateJacobianUpdateRule, AbstractDescentDirection,
AbstractApproximateJacobianUpdateRuleCache,
AbstractDampingFunction, AbstractDampingFunctionCache,
AbstractTrustRegionMethod, AbstractTrustRegionMethodCache,
AbstractNonlinearSolveCache, AbstractDampingFunction,
AbstractDampingFunctionCache, AbstractTrustRegionMethod,
AbstractTrustRegionMethodCache,
Utils, InternalAPI, get_timer_output, @static_timeit,
update_trace!, L2_NORM,
NewtonDescent, DampedNewtonDescent
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode
using SciMLOperators: AbstractSciMLOperator
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
Dogleg
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode, NonlinearFunction
using Setfield: @set!
using StaticArraysCore: StaticArray, SArray, Size, MArray
using StaticArraysCore: SArray

include("raphson.jl")
include("gauss_newton.jl")
Expand All @@ -37,6 +33,29 @@ include("pseudo_transient.jl")

include("solve.jl")

@setup_workload begin
include(joinpath(
@__DIR__, "..", "..", "..", "common", "nonlinear_problem_workloads.jl"
))
include(joinpath(
@__DIR__, "..", "..", "..", "common", "nlls_problem_workloads.jl"
))

# XXX: TrustRegion
nlp_algs = [NewtonRaphson(), LevenbergMarquardt()]
nlls_algs = [GaussNewton(), LevenbergMarquardt()]

@compile_workload begin
for prob in nonlinear_problems, alg in nlp_algs
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
end

for prob in nlls_problems, alg in nlls_algs
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
end
end
end

@reexport using SciMLBase, NonlinearSolveBase

export NewtonRaphson, PseudoTransient
Expand Down
6 changes: 3 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/levenberg_marquardt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ end
function InternalAPI.solve!(
cache::LevenbergMarquardtDampingCache, J, fu, ::Val{false}; kwargs...
)
if ArrayInterface.can_setindex(cache.J_diag_cache)
sum!(abs2, Utils.safe_vec(cache.J_diag_cache), J')
elseif cache.J_diag_cache isa Number
if cache.J_diag_cache isa Number
cache.J_diag_cache = abs2(J)
elseif ArrayInterface.can_setindex(cache.J_diag_cache)
sum!(abs2, Utils.safe_vec(cache.J_diag_cache), J')
else
cache.J_diag_cache = dropdims(sum(abs2, J'; dims = 1); dims = 1)
end
Expand Down
39 changes: 39 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/trust_region.jl
Original file line number Diff line number Diff line change
@@ -1 +1,40 @@
"""
TrustRegion(;
concrete_jac = nothing, linsolve = nothing, precs = nothing,
radius_update_scheme = RadiusUpdateSchemes.Simple, max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1, step_threshold::Real = 1 // 10000,
shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4, expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32,
vjp_autodiff = nothing, autodiff = nothing, jvp_autodiff = nothing
)
An advanced TrustRegion implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear systems.
### Keyword Arguments
- `radius_update_scheme`: the scheme used to update the trust region radius. Defaults to
`RadiusUpdateSchemes.Simple`. See [`RadiusUpdateSchemes`](@ref) for more details. For a
review on trust region radius update schemes, see [yuan2015recent](@citet).
For the remaining arguments, see [`NonlinearSolve.GenericTrustRegionScheme`](@ref)
documentation.
"""
function TrustRegion(;
concrete_jac = nothing, linsolve = nothing, precs = nothing,
radius_update_scheme = RadiusUpdateSchemes.Simple, max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1, step_threshold::Real = 1 // 10000,
shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4, expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32,
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
)
descent = Dogleg(; linsolve, precs)
trustregion = GenericTrustRegionScheme(;
method = radius_update_scheme, step_threshold, shrink_threshold, expand_threshold,
shrink_factor, expand_factor, initial_trust_radius, max_trust_radius)
return GeneralizedFirstOrderAlgorithm{concrete_jac, :TrustRegion}(;
trustregion, descent, autodiff, vjp_autodiff, jvp_autodiff, max_shrink_times)
end
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,8 @@ include("solve.jl")
algs = [Broyden(), Klement()]

@compile_workload begin
@sync begin
for prob in nonlinear_problems, alg in algs
Threads.@spawn CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
end
for prob in nonlinear_problems, alg in algs
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
end
end
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ include("solve.jl")
algs = [DFSane()]

@compile_workload begin
@sync begin
for prob in nonlinear_problems, alg in algs
Threads.@spawn CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
end
for prob in nonlinear_problems, alg in algs
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
end
end
end
Expand Down
6 changes: 2 additions & 4 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,8 @@ function solve_adjoint_internal end
#!format: on

@compile_workload begin
@sync for alg in algs
for prob in (prob_scalar, prob_iip, prob_oop)
Threads.@spawn CommonSolve.solve(prob, alg; abstol = 1e-2)
end
for prob in (prob_scalar, prob_iip, prob_oop)
CommonSolve.solve(prob, alg; abstol = 1e-2)
end
end
end
Expand Down
56 changes: 12 additions & 44 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,9 @@ using SparseMatrixColorings: SparseMatrixColorings # NOTE: This triggers an exte

const DI = DifferentiationInterface

const True = Val(true)
const False = Val(false)

include("timer_outputs.jl")
include("internal/helpers.jl")

include("globalization/trust_region.jl")

include("core/generalized_first_order.jl")

include("algorithms/raphson.jl")
include("algorithms/pseudo_transient.jl")
include("algorithms/gauss_newton.jl")
include("algorithms/levenberg_marquardt.jl")
include("algorithms/trust_region.jl")

include("algorithms/extension_algs.jl")

include("utils.jl")
Expand Down Expand Up @@ -100,9 +87,6 @@ include("internal/forward_diff.jl") # we need to define after the algorithms
end

nls_algs = (
NewtonRaphson(),
TrustRegion(),
LevenbergMarquardt(),
nothing
)

Expand Down Expand Up @@ -132,44 +116,28 @@ include("internal/forward_diff.jl") # we need to define after the algorithms
)

@compile_workload begin
@sync begin
for prob in probs_nls, alg in nls_algs
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
end
for prob in probs_nlls, alg in nlls_algs
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
end
for prob in probs_nls, alg in nls_algs
solve(prob, alg; abstol = 1e-2, verbose = false)
end
for prob in probs_nlls, alg in nlls_algs
solve(prob, alg; abstol = 1e-2, verbose = false)
end
end
end

# Rexexports
@reexport using SciMLBase, SimpleNonlinearSolve, NonlinearSolveBase,
NonlinearSolveSpectralMethods, NonlinearSolveQuasiNewton
@reexport using SciMLBase, NonlinearSolveBase
@reexport using NonlinearSolveFirstOrder, NonlinearSolveSpectralMethods,
NonlinearSolveQuasiNewton, SimpleNonlinearSolve
@reexport using LineSearch
@reexport using ADTypes

# Core Algorithms
export NewtonRaphson, PseudoTransient
export GaussNewton, LevenbergMarquardt, TrustRegion
export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg,
FastShortcutNLLSPolyalg
export NonlinearSolvePolyAlgorithm, RobustMultiNewton,
FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg

# Extension Algorithms
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
export PETScSNES, CMINPACK

# Advanced Algorithms -- Without Bells and Whistles
export GeneralizedFirstOrderAlgorithm

# Globalization
## Line Search Algorithms
export LineSearch, BackTracking, NoLineSearch, RobustNonMonotoneLineSearch,
LiFukushimaLineSearch, LineSearchesJL
## Trust Region Algorithms
export RadiusUpdateSchemes

# Reexport ADTypes
export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff, AutoZygote, AutoEnzyme,
AutoSparse

end

0 comments on commit 4e1098f

Please sign in to comment.