From 7e26d18c78173f887c7bda5c2a0b1bc20112701d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Sep 2023 14:26:23 -0400 Subject: [PATCH] Add support for line search in Newton Raphson --- Project.toml | 2 + src/NonlinearSolve.jl | 5 +- src/jacobian.jl | 11 ++-- src/levenberg.jl | 12 ++-- src/linesearch.jl | 146 ++++++++++++++++++++++++++++++++++++++++++ src/raphson.jl | 35 ++++++---- src/trustRegion.jl | 11 +--- src/utils.jl | 23 +++++++ test/basictests.jl | 56 +++++++++------- 9 files changed, 241 insertions(+), 60 deletions(-) create mode 100644 src/linesearch.jl diff --git a/Project.toml b/Project.toml index ab1f6a500..f5d4ddcef 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -33,6 +34,7 @@ Enzyme = "0.11" FiniteDiff = "2" ForwardDiff = "0.10.3" LinearSolve = "2" +LineSearches = "7" PrecompileTools = "1" RecursiveArrayTools = "2" Reexport = "0.2, 1" diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 2f851faa3..615f96c03 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -20,7 +20,7 @@ import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isi import StaticArraysCore: StaticArray, SVector, SArray, MArray import UnPack: @unpack -@reexport using ADTypes, SciMLBase, SimpleNonlinearSolve +@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences, ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode} @@ -35,6 +35,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAl end include("utils.jl") +include("linesearch.jl") include("raphson.jl") include("trustRegion.jl") include("levenberg.jl") @@ -69,4 +70,6 @@ export RadiusUpdateSchemes export NewtonRaphson, TrustRegion, LevenbergMarquardt +export LineSearch + end # module diff --git a/src/jacobian.jl b/src/jacobian.jl index aea7b4270..83d26fee6 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -9,7 +9,7 @@ end (uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p))) (uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p) -sparsity_detection_alg(f, ad) = NoSparsityDetection() +sparsity_detection_alg(_, _) = NoSparsityDetection() function sparsity_detection_alg(f, ad::AbstractSparseADType) if f.sparsity === nothing if f.jac_prototype === nothing @@ -49,8 +49,8 @@ end jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u)) # Build Jacobian Caches -function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, - ::Val{iip}) where {iip} +function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{iip}; + linsolve_kwargs=(;)) where {iip} uf = JacobianWrapper{iip}(f, p) haslinsolve = hasfield(typeof(alg), :linsolve) @@ -92,14 +92,15 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing, nothing)..., weight) - linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr) + linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr, + 
+        linsolve_kwargs...)
 
     return uf, linsolve, J, fu, jac_cache, du
 end
 
 ## Special Handling for Scalars
 function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
-        ::Val{false})
+        ::Val{false}; kwargs...)
     # NOTE: Scalar `u` assumes scalar output from `f`
     uf = JacobianWrapper{false}(f, p)
     return uf, nothing, u, nothing, nothing, u
diff --git a/src/levenberg.jl b/src/levenberg.jl
index 6265eba3f..17f61475f 100644
--- a/src/levenberg.jl
+++ b/src/levenberg.jl
@@ -142,16 +142,12 @@ isinplace(::LevenbergMarquardtCache{iip}) where {iip} = iip
 function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
     args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
-    kwargs...) where {uType, iip}
+    linsolve_kwargs=(;), kwargs...) where {uType, iip}
     @unpack f, u0, p = prob
     u = alias_u0 ? u0 : deepcopy(u0)
-    if iip
-        fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
-        f(fu1, u, p)
-    else
-        fu1 = f(u, p)
-    end
-    uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
+    fu1 = evaluate_f(prob, u)
+    uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
+        linsolve_kwargs)
 
     λ = convert(eltype(u), alg.damping_initial)
     λ_factor = convert(eltype(u), alg.damping_increase_factor)
diff --git a/src/linesearch.jl b/src/linesearch.jl
new file mode 100644
index 000000000..3890f8230
--- /dev/null
+++ b/src/linesearch.jl
@@ -0,0 +1,146 @@
+"""
+    LineSearch(method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
+
+Wrapper over algorithms from
+[LineSearches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl/). Allows automatic
+construction of the objective functions for the line search algorithms utilizing automatic
+differentiation for fast Vector Jacobian Products.
+
+### Arguments
+
+  - `method`: the line search algorithm to use. Defaults to `Static()`, which means that the
+    step size is fixed to the value of `alpha`.
+  - `autodiff`: the automatic differentiation backend to use for the line search. Defaults to
+    `AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP.
+    `AutoZygote()` will be faster in most cases, but it requires `Zygote.jl` to be manually
+    installed and loaded.
+  - `alpha`: the initial step size to use. Defaults to `true` (which is equivalent to `1`).
+""" +@concrete struct LineSearch + method + autodiff + α +end + +function LineSearch(; method = Static(), autodiff = AutoFiniteDiff(), alpha = true) + return LineSearch(method, autodiff, alpha) +end + +@concrete mutable struct LineSearchCache + f + ϕ + dϕ + ϕdϕ + α + ls +end + +function LineSearchCache(ls::LineSearch, f, u::Number, p, _, ::Val{false}) + eval_f(u, du, α) = eval_f(u - α * du) + eval_f(u) = f(u, p) + + ls.method isa Static && return LineSearchCache(eval_f, nothing, nothing, nothing, + convert(typeof(u), ls.α), ls) + + g(u, fu) = last(value_derivative(Base.Fix2(f, p), u)) * fu + + function ϕ(u, du) + function ϕ_internal(α) + u_ = u - α * du + _fu = eval_f(u_) + return dot(_fu, _fu) / 2 + end + return ϕ_internal + end + + function dϕ(u, du) + function dϕ_internal(α) + u_ = u - α * du + _fu = eval_f(u_) + g₀ = g(u_, _fu) + return dot(g₀, -du) + end + return dϕ_internal + end + + function ϕdϕ(u, du) + function ϕdϕ_internal(α) + u_ = u - α * du + _fu = eval_f(u_) + g₀ = g(u_, _fu) + return dot(_fu, _fu) / 2, dot(g₀, -du) + end + return ϕdϕ_internal + end + + return LineSearchCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls) +end + +function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip} + fu = iip ? fu1 : nothing + u_ = _mutable_zero(u) + + function eval_f(u, du, α) + @. u_ = u - α * du + return eval_f(u_) + end + eval_f(u) = evaluate_f(f, u, p, IIP; fu) + + ls.method isa Static && return LineSearchCache(eval_f, nothing, nothing, nothing, + convert(eltype(u), ls.α), ls) + + g₀ = _mutable_zero(u) + + function g!(u, fu) + op = VecJac((args...) -> f(args..., p), u) + if iip + mul!(g₀, op, fu) + return g₀ + else + return op * fu + end + end + + function ϕ(u, du) + function ϕ_internal(α) + @. u_ = u - α * du + _fu = eval_f(u_) + return dot(_fu, _fu) / 2 + end + return ϕ_internal + end + + function dϕ(u, du) + function dϕ_internal(α) + @. u_ = u - α * du + _fu = eval_f(u_) + g₀ = g!(u_, _fu) + return dot(g₀, -du) + end + return dϕ_internal + end + + function ϕdϕ(u, du) + function ϕdϕ_internal(α) + @. u_ = u - α * du + _fu = eval_f(u_) + g₀ = g!(u_, _fu) + return dot(_fu, _fu) / 2, dot(g₀, -du) + end + return ϕdϕ_internal + end + + return LineSearchCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls) +end + +function perform_linesearch!(cache::LineSearchCache, u, du) + cache.ls.method isa Static && return (cache.α, cache.f(u, du, cache.α)) + + ϕ = cache.ϕ(u, du) + dϕ = cache.dϕ(u, du) + ϕdϕ = cache.ϕdϕ(u, du) + + ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u))) + + return cache.ls.method(ϕ, cache.dϕ(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀) +end diff --git a/src/raphson.jl b/src/raphson.jl index 33d12c4ba..d01881dc4 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -25,19 +25,24 @@ for large-scale and numerically-difficult nonlinear systems. preconditioners. For more information on specifying preconditioners for LinearSolve algorithms, consult the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/). + - `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref), + which means that no line search is performed. Algorithms from `LineSearches.jl` can be + used here directly, and they will be converted to the correct `LineSearch`. """ @concrete struct NewtonRaphson{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD} ad::AD linsolve precs + linesearch end concrete_jac(::NewtonRaphson{CJ}) where {CJ} = CJ function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, - precs = DEFAULT_PRECS, adkwargs...) 
+ linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) - return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs) + linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method=linesearch) + return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch) end @concrete mutable struct NewtonRaphsonCache{iip} @@ -59,26 +64,23 @@ end abstol prob stats::NLStats + lscache end isinplace(::NewtonRaphsonCache{iip}) where {iip} = iip function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson, args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, - kwargs...) where {uType, iip} + linsolve_kwargs=(;), kwargs...) where {uType, iip} @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) - if iip - fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype - f(fu1, u, p) - else - fu1 = _mutable(f(u, p)) - end - uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip)) + fu1 = evaluate_f(prob, u) + uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); + linsolve_kwargs) return NewtonRaphsonCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, - NLStats(1, 0, 0, 0, 0)) + NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip))) end function perform_step!(cache::NewtonRaphsonCache{true}) @@ -89,8 +91,10 @@ function perform_step!(cache::NewtonRaphsonCache{true}) linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du), p, reltol = cache.abstol) cache.linsolve = linres.cache - @. u = u - du - f(fu1, u, p) + + # Line Search + α, _ = perform_linesearch!(cache.lscache, u, du) + @. u = u - α * du cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true) cache.stats.nf += 1 @@ -112,7 +116,10 @@ function perform_step!(cache::NewtonRaphsonCache{false}) linu = _vec(cache.du), p, reltol = cache.abstol) cache.linsolve = linres.cache end - cache.u = @. u - cache.du # `u` might not support mutation + + # Line Search + α, _fu = perform_linesearch!(cache.lscache, u, cache.du) + cache.u = @. u - α * cache.du # `u` might not support mutation cache.fu1 = f(cache.u, p) cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 41ccb994e..e0892a4da 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -202,20 +202,15 @@ end function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, args...; alias_u0 = false, maxiters = 1000, abstol = 1e-8, internalnorm = DEFAULT_NORM, - kwargs...) where {uType, iip} + linsolve_kwargs=(;), kwargs...) where {uType, iip} @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) u_prev = zero(u) - if iip - fu1 = f.resid_prototype === nothing ? 
zero(u) : f.resid_prototype - f(fu1, u, p) - else - fu1 = f(u, p) - end + fu1 = evaluate_f(prob, u) fu_prev = zero(fu1) loss = get_loss(fu1) - uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip)) + uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs) radius_update_scheme = alg.radius_update_scheme max_trust_radius = convert(eltype(u), alg.max_trust_radius) diff --git a/src/utils.jl b/src/utils.jl index 3df540632..7498d5afa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -142,3 +142,26 @@ _maybe_mutable(x, ::AbstractFiniteDifferencesMode) = _mutable(x) # The shadow allocated for Enzyme needs to be mutable _maybe_mutable(x, ::AutoSparseEnzyme) = _mutable(x) _maybe_mutable(x, _) = x + +# Helper function to get value of `f(u, p)` +function evaluate_f(prob::NonlinearProblem{uType, iip}, u) where {uType, iip} + @unpack f, u0, p = prob + if iip + fu = f.resid_prototype === nothing ? zero(u) : f.resid_prototype + f(fu, u, p) + else + fu = _mutable(f(u, p)) + end + return fu +end + +evaluate_f(cache, u; fu = nothing) = evaluate_f(cache.f, u, cache.p, Val(cache.iip); fu) + +function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip} + if iip + f(fu, u, p) + return fu + else + return f(u, p) + end +end diff --git a/test/basictests.jl b/test/basictests.jl index 11e64307d..c31be05fa 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -21,41 +21,49 @@ end # --- NewtonRaphson tests --- @testset "NewtonRaphson" begin - function benchmark_nlsolve_oop(f, u0, p = 2.0) + function benchmark_nlsolve_oop(f, u0, p = 2.0; linesearch = LineSearch()) prob = NonlinearProblem{false}(f, u0, p) - return solve(prob, NewtonRaphson(), abstol = 1e-9) + return solve(prob, NewtonRaphson(; linesearch), abstol = 1e-9) end - function benchmark_nlsolve_iip(f, u0, p = 2.0; linsolve, precs) + function benchmark_nlsolve_iip(f, u0, p = 2.0; linsolve, precs, + linesearch = LineSearch()) prob = NonlinearProblem{true}(f, u0, p) - return solve(prob, NewtonRaphson(; linsolve, precs), abstol = 1e-9) + return solve(prob, NewtonRaphson(; linsolve, precs, linesearch), abstol = 1e-9) end - u0s = VERSION ≥ v"1.9" ? ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) : ([1.0, 1.0], 1.0) - @testset "[OOP] u0: $(typeof(u0))" for u0 in u0s - sol = benchmark_nlsolve_oop(quadratic_f, u0) - @test SciMLBase.successful_retcode(sol) - @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (Static(), + StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()), + ad in (AutoFiniteDiff(), AutoZygote()) - cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), NewtonRaphson(), - abstol = 1e-9) - @test (@ballocated solve!($cache)) < 200 - end + linesearch = LineSearch(; method = lsmethod, autodiff = ad) + u0s = VERSION ≥ v"1.9" ? ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) : ([1.0, 1.0], 1.0) - precs = [NonlinearSolve.DEFAULT_PRECS, :Random] + @testset "[OOP] u0: $(typeof(u0))" for u0 in u0s + sol = benchmark_nlsolve_oop(quadratic_f, u0; linesearch) + @test SciMLBase.successful_retcode(sol) + @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) - @testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([ - 1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES()) - if prec === :Random - prec = (args...) 
-> (Diagonal(randn!(similar(u0))), nothing) + cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), NewtonRaphson(), + abstol = 1e-9) + @test (@ballocated solve!($cache)) < 200 end - sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec) - @test SciMLBase.successful_retcode(sol) - @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) - cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0), - NewtonRaphson(; linsolve, precs = prec), abstol = 1e-9) - @test (@ballocated solve!($cache)) ≤ 64 + precs = [NonlinearSolve.DEFAULT_PRECS, :Random] + + @testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([ + 1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES()) + if prec === :Random + prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing) + end + sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec, linesearch) + @test SciMLBase.successful_retcode(sol) + @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + + cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0), + NewtonRaphson(; linsolve, precs = prec), abstol = 1e-9) + @test (@ballocated solve!($cache)) ≤ 64 + end end if VERSION ≥ v"1.9"
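
Note (not part of the diff): a minimal usage sketch of the API this patch adds, assuming NonlinearSolve.jl with the patch applied. `BackTracking` comes from LineSearches.jl, which is now re-exported; the problem, tolerances, and variable names below are illustrative, mirroring the quadratic test in basictests.jl.

using NonlinearSolve  # re-exports LineSearches algorithms such as BackTracking

f(u, p) = u .* u .- p                         # simple quadratic test problem
prob = NonlinearProblem{false}(f, [1.0, 1.0], 2.0)

# Default: LineSearch() with Static(), i.e. plain Newton steps of size alpha = 1
sol = solve(prob, NewtonRaphson(), abstol = 1e-9)

# A bare LineSearches.jl algorithm is wrapped into LineSearch by the constructor
sol_bt = solve(prob, NewtonRaphson(; linesearch = BackTracking()), abstol = 1e-9)

# Explicit wrapper, also choosing the AD backend used for the line-search objective
ls = LineSearch(; method = BackTracking(), autodiff = AutoFiniteDiff())
sol_ls = solve(prob, NewtonRaphson(; linesearch = ls), abstol = 1e-9)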