Skip to content

Commit

Permalink
refactor: migrate to LineSearch.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 26, 2024
1 parent 05aa3db commit 40c4df9
Show file tree
Hide file tree
Showing 11 changed files with 324 additions and 540 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand Down Expand Up @@ -78,6 +79,7 @@ Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LazyArrays = "1.8.2, 2"
LeastSquaresOptim = "0.8.5"
LineSearch = "0.1"
LineSearches = "7.2"
LinearAlgebra = "1.10"
LinearSolve = "2.30"
Expand Down
7 changes: 0 additions & 7 deletions docs/src/devdocs/internal_interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ NonlinearSolve.AbstractDampingFunction
NonlinearSolve.AbstractDampingFunctionCache
```

## Line Search

```@docs
NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm
NonlinearSolve.AbstractNonlinearSolveLineSearchCache
```

## Trust Region

```@docs
Expand Down
98 changes: 50 additions & 48 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ using LazyArrays: LazyArrays, ApplyArray, cache
using LinearAlgebra: LinearAlgebra, ColumnNorm, Diagonal, I, LowerTriangular, Symmetric,
UpperTriangular, axpy!, cond, diag, diagind, dot, issuccess, istril,
istriu, lu, mul!, norm, pinv, tril!, triu!
using LineSearch: LineSearch, AbstractLineSearchAlgorithm, AbstractLineSearchCache,
NoLineSearch
using LineSearches: LineSearches
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
InvPreconditioner, needs_concrete_A, AbstractFactorization,
Expand Down Expand Up @@ -103,54 +105,54 @@ include("algorithms/extension_algs.jl")
include("utils.jl")
include("default.jl")

@setup_workload begin
nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
probs_nls = NonlinearProblem[]
for (fn, u0) in nlfuncs
push!(probs_nls, NonlinearProblem(fn, u0, 2.0))
end

nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
PseudoTransient(), Broyden(), Klement(), DFSane(), nothing)

probs_nlls = NonlinearLeastSquaresProblem[]
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
(
NonlinearFunction{true}(
(du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)),
[0.1, 0.0]),
(
NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
resid_prototype = zeros(4)),
[0.1, 0.1]))
for (fn, u0) in nlfuncs
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
end

nlls_algs = (LevenbergMarquardt(), GaussNewton(), TrustRegion(),
LevenbergMarquardt(; linsolve = LUFactorization()),
GaussNewton(; linsolve = LUFactorization()),
TrustRegion(; linsolve = LUFactorization()), nothing)

@compile_workload begin
@sync begin
for T in (Float32, Float64), (fn, u0) in nlfuncs
Threads.@spawn NonlinearProblem(fn, T.(u0), T(2))
end
for (fn, u0) in nlfuncs
Threads.@spawn NonlinearLeastSquaresProblem(fn, u0, 2.0)
end
for prob in probs_nls, alg in nls_algs
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
end
for prob in probs_nlls, alg in nlls_algs
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
end
end
end
end
# @setup_workload begin
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
# probs_nls = NonlinearProblem[]
# for (fn, u0) in nlfuncs
# push!(probs_nls, NonlinearProblem(fn, u0, 2.0))
# end

# nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
# PseudoTransient(), Broyden(), Klement(), DFSane(), nothing)

# probs_nlls = NonlinearLeastSquaresProblem[]
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
# (
# NonlinearFunction{true}(
# (du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)),
# [0.1, 0.0]),
# (
# NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
# resid_prototype = zeros(4)),
# [0.1, 0.1]))
# for (fn, u0) in nlfuncs
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
# end

# nlls_algs = (LevenbergMarquardt(), GaussNewton(), TrustRegion(),
# LevenbergMarquardt(; linsolve = LUFactorization()),
# GaussNewton(; linsolve = LUFactorization()),
# TrustRegion(; linsolve = LUFactorization()), nothing)

# @compile_workload begin
# @sync begin
# for T in (Float32, Float64), (fn, u0) in nlfuncs
# Threads.@spawn NonlinearProblem(fn, T.(u0), T(2))
# end
# for (fn, u0) in nlfuncs
# Threads.@spawn NonlinearLeastSquaresProblem(fn, u0, 2.0)
# end
# for prob in probs_nls, alg in nls_algs
# Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
# end
# for prob in probs_nlls, alg in nlls_algs
# Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
# end
# end
# end
# end

# Core Algorithms
export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane
Expand Down
22 changes: 3 additions & 19 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,6 @@ function last_step_accepted(cache::AbstractDescentCache)
return true
end

"""
AbstractNonlinearSolveLineSearchAlgorithm
Abstract Type for all Line Search Algorithms used in NonlinearSolve.jl.
### `__internal_init` specification
```julia
__internal_init(
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveLineSearchAlgorithm, f::F,
fu, u, p, args...; internalnorm::IN = DEFAULT_NORM, kwargs...) where {F, IN} -->
AbstractNonlinearSolveLineSearchCache
```
"""
abstract type AbstractNonlinearSolveLineSearchAlgorithm end

"""
AbstractNonlinearSolveLineSearchCache
Expand Down Expand Up @@ -512,9 +496,9 @@ SciMLBase.isinplace(::AbstractNonlinearSolveJacobianCache{iip}) where {iip} = ii
abstract type AbstractNonlinearSolveTraceLevel end

# Default Printing
for aType in (AbstractTrustRegionMethod, AbstractNonlinearSolveLineSearchAlgorithm,
AbstractResetCondition, AbstractApproximateJacobianUpdateRule,
AbstractDampingFunction, AbstractNonlinearSolveExtensionAlgorithm)
for aType in (AbstractTrustRegionMethod, AbstractResetCondition,
AbstractApproximateJacobianUpdateRule, AbstractDampingFunction,
AbstractNonlinearSolveExtensionAlgorithm)
@eval function Base.show(io::IO, alg::$(aType))
print(io, "$(nameof(typeof(alg)))()")
end
Expand Down
7 changes: 4 additions & 3 deletions src/algorithms/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ For other keyword arguments, see [`RobustNonMonotoneLineSearch`](@ref).
function DFSane(; σ_min = 1 // 10^10, σ_max = 1e10, σ_1 = 1, M::Int = 10, γ = 1 // 10^4,
τ_min = 1 // 10, τ_max = 1 // 2, n_exp::Int = 2, max_inner_iterations::Int = 100,
η_strategy::ETA = (fn_1, n, x_n, f_n) -> fn_1 / n^2) where {ETA}
linesearch = RobustNonMonotoneLineSearch(;
gamma = γ, sigma_1 = σ_1, M, tau_min = τ_min, tau_max = τ_max,
n_exp, η_strategy, maxiters = max_inner_iterations)
# linesearch = RobustNonMonotoneLineSearch(;
# gamma = γ, sigma_1 = σ_1, M, tau_min = τ_min, tau_max = τ_max,
# n_exp, η_strategy, maxiters = max_inner_iterations)
linesearch = NoLineSearch()
return GeneralizedDFSane{:DFSane}(linesearch, σ_min, σ_max, nothing)
end
12 changes: 6 additions & 6 deletions src/algorithms/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ over this.
function Klement(; max_resets::Int = 100, linsolve = nothing, alpha = nothing,
linesearch = NoLineSearch(), precs = DEFAULT_PRECS,
autodiff = nothing, init_jacobian::Val{IJ} = Val(:identity)) where {IJ}
if !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
Base.depwarn(
"Passing in a `LineSearches.jl` algorithm directly is deprecated. \
Please use `LineSearchesJL` instead.", :Klement)
linesearch = LineSearchesJL(; method = linesearch)
end
# if !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
# Base.depwarn(
# "Passing in a `LineSearches.jl` algorithm directly is deprecated. \
# Please use `LineSearchesJL` instead.", :Klement)
# linesearch = LineSearchesJL(; method = linesearch)
# end

if IJ === :identity
initialization = IdentityInitialization(alpha, DiagonalStructure())
Expand Down
7 changes: 3 additions & 4 deletions src/algorithms/pseudo_transient.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
precs = DEFAULT_PRECS, autodiff = nothing)
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing)
An implementation of PseudoTransient Method [coffey2003pseudotransient](@cite) that is used
to solve steady state problems in an accelerated manner. It uses an adaptive time-stepping
Expand All @@ -16,8 +15,8 @@ This implementation specifically uses "switched evolution relaxation"
you are going to need more iterations to converge but it can be more stable.
"""
function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
precs = DEFAULT_PRECS, autodiff = nothing, alpha_initial = 1e-3)
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing,
alpha_initial = 1e-3)
descent = DampedNewtonDescent(; linsolve, precs, initial_damping = alpha_initial,
damping_fn = SwitchedEvolutionRelaxation())
return GeneralizedFirstOrderAlgorithm(;
Expand Down
12 changes: 6 additions & 6 deletions src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ function ApproximateJacobianSolveAlgorithm{concrete_jac, name}(;
linesearch = missing, trustregion = missing, descent, update_rule,
reinit_rule, initialization, max_resets::Int = typemax(Int),
max_shrink_times::Int = typemax(Int)) where {concrete_jac, name}
if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
Please use `LineSearchesJL` instead.",
:GeneralizedFirstOrderAlgorithm)
linesearch = LineSearchesJL(; method = linesearch)
end
# if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
# Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
# Please use `LineSearchesJL` instead.",
# :GeneralizedFirstOrderAlgorithm)
# linesearch = LineSearchesJL(; method = linesearch)
# end
return ApproximateJacobianSolveAlgorithm{concrete_jac, name}(
linesearch, trustregion, descent, update_rule,
reinit_rule, max_resets, max_shrink_times, initialization)
Expand Down
24 changes: 14 additions & 10 deletions src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ function GeneralizedFirstOrderAlgorithm{concrete_jac, name}(;
jacobian_ad !== nothing && ADTypes.mode(jacobian_ad) isa ADTypes.ReverseMode,
jacobian_ad, nothing))

if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
Please use `LineSearchesJL` instead.",
:GeneralizedFirstOrderAlgorithm)
linesearch = LineSearchesJL(; method = linesearch)
end
# if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
# Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
# Please use `LineSearchesJL` instead.",
# :GeneralizedFirstOrderAlgorithm)
# linesearch = LineSearchesJL(; method = linesearch)
# end

return GeneralizedFirstOrderAlgorithm{concrete_jac, name}(
linesearch, trustregion, descent, max_shrink_times,
Expand Down Expand Up @@ -199,8 +199,11 @@ function SciMLBase.__init(
if alg.linesearch !== missing
supports_line_search(alg.descent) || error("Line Search not supported by \
$(alg.descent).")
linesearch_cache = __internal_init(
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
linesearch_ad = alg.forward_ad === nothing ?
(alg.reverse_ad === nothing ? alg.jacobian_ad :
alg.reverse_ad) : alg.forward_ad
linesearch_cache = init(
prob, alg.linesearch, fu, u; stats, autodiff = linesearch_ad, kwargs...)
GB = :LineSearch
end

Expand Down Expand Up @@ -264,8 +267,9 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
cache.make_new_jacobian = true
if GB === :LineSearch
@static_timeit cache.timer "linesearch" begin
linesearch_failed, α = __internal_solve!(
cache.linesearch_cache, cache.u, δu)
linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
α = linesearch_sol.step_size
end
if linesearch_failed
cache.retcode = ReturnCode.InternalLineSearchFailed
Expand Down
5 changes: 2 additions & 3 deletions src/core/spectral_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ Method.
### Arguments
- `linesearch`: Globalization using a Line Search Method. This needs to follow the
[`NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm`](@ref) interface. This
is not optional currently, but that restriction might be lifted in the future.
- `linesearch`: Globalization using a Line Search Method. This is not optional currently,
but that restriction might be lifted in the future.
- `σ_min`: The minimum spectral parameter allowed. This is used to ensure that the
spectral parameter is not too small.
- `σ_max`: The maximum spectral parameter allowed. This is used to ensure that the
Expand Down
Loading

0 comments on commit 40c4df9

Please sign in to comment.