From b821eafd8d198bc133a720b3de66b4808bd1d753 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Oct 2023 20:30:30 -0400 Subject: [PATCH] Fast General Klement Implementation --- docs/src/solvers/NonlinearSystemSolvers.md | 2 + src/NonlinearSolve.jl | 13 +- src/broyden.jl | 5 +- src/default.jl | 4 +- src/klement.jl | 190 +++++++++++++++++++++ test/23_test_problems.jl | 11 ++ test/basictests.jl | 89 ++++++++++ 7 files changed, 302 insertions(+), 12 deletions(-) diff --git a/docs/src/solvers/NonlinearSystemSolvers.md b/docs/src/solvers/NonlinearSystemSolvers.md index a39ab597e..89cde3513 100644 --- a/docs/src/solvers/NonlinearSystemSolvers.md +++ b/docs/src/solvers/NonlinearSystemSolvers.md @@ -67,6 +67,8 @@ features, but have a bit of overhead on very small problems. robustnes on the hard problems. - `GeneralBroyden()`: Generalization of Broyden's Quasi-Newton Method with Line Search and Automatic Jacobian Resetting. This is a fast method but unstable for most problems! + - `GeneralKlement()`: Generalization of Klement's Quasi-Newton Method with Line Search and + Automatic Jacobian Resetting. This is a fast method but unstable for most problems! 
### SimpleNonlinearSolve.jl diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 5e21cdf44..2fad1e083 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -52,10 +52,13 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache) cache.stats.nsteps += 1 end - if cache.stats.nsteps == cache.maxiters - cache.retcode = ReturnCode.MaxIters - else - cache.retcode = ReturnCode.Success + # The solver might have set a different `retcode` + if cache.retcode == ReturnCode.Default + if cache.stats.nsteps == cache.maxiters + cache.retcode = ReturnCode.MaxIters + else + cache.retcode = ReturnCode.Success + end end return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, get_fu(cache); @@ -85,7 +88,7 @@ import PrecompileTools prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), - nothing) + PseudoTransient(), GeneralBroyden(), nothing) for alg in precompile_algs solve(prob, alg, abstol = T(1e-2)) diff --git a/src/broyden.jl b/src/broyden.jl index 1f8bc73bc..6a767a8af 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -113,10 +113,7 @@ function perform_step!(cache::GeneralBroydenCache{false}) cache.dfu = cache.fu2 .- cache.fu if cache.resets < cache.max_resets && (all(x -> abs(x) ≤ 1e-12, cache.du) || all(x -> abs(x) ≤ 1e-12, cache.dfu)) - J⁻¹ = similar(cache.J⁻¹) - fill!(J⁻¹, 0) - J⁻¹[diagind(J⁻¹)] .= T(1) - cache.J⁻¹ = J⁻¹ + cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu) cache.resets += 1 else cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu)) diff --git a/src/default.jl b/src/default.jl index 76289cc1c..07effeecb 100644 --- a/src/default.jl +++ b/src/default.jl @@ -159,10 +159,8 @@ end ] else [ - # FIXME: Broyden and Klement are type unstable - # (upstream SimpleNonlinearSolve.jl issue) - !iip ? 
:(Klement()) : nothing, # Klement not yet implemented for IIP :(GeneralBroyden()), + :(GeneralKlement()), :(NewtonRaphson(; linsolve, precs, adkwargs...)), :(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)), :(TrustRegion(; linsolve, precs, adkwargs...)), diff --git a/src/klement.jl b/src/klement.jl index 8b1378917..42cac5bbe 100644 --- a/src/klement.jl +++ b/src/klement.jl @@ -1 +1,191 @@ +@concrete struct GeneralKlement <: AbstractNewtonAlgorithm{false, Nothing} + max_resets::Int + linsolve + precs + linesearch + singular_tolerance +end +function GeneralKlement(; max_resets::Int = 5, linsolve = nothing, + linesearch = LineSearch(), precs = DEFAULT_PRECS, singular_tolerance = nothing) + linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch) + return GeneralKlement(max_resets, linsolve, precs, linesearch, singular_tolerance) +end + +@concrete mutable struct GeneralKlementCache{iip} <: AbstractNonlinearSolveCache{iip} + f + alg + u + fu + fu2 + du + p + linsolve + J + J_cache + J_cache2 + Jᵀ²du + Jdu + resets + singular_tolerance + force_stop + maxiters::Int + internalnorm + retcode::ReturnCode.T + abstol + prob + stats::NLStats + lscache +end + +get_fu(cache::GeneralKlementCache) = cache.fu + +function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlement, args...; + alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + linsolve_kwargs = (;), kwargs...) where {uType, iip} + @unpack f, u0, p = prob + u = alias_u0 ? u0 : deepcopy(u0) + fu = evaluate_f(prob, u) + J = __init_identity_jacobian(u, fu) + + if u isa Number + linsolve = nothing + else + weight = similar(u) + recursivefill!(weight, true) + Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing, + nothing)..., weight) + linprob = LinearProblem(J, _vec(fu); u0 = _vec(fu)) + linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr, + linsolve_kwargs...) 
+ end + + singular_tolerance = alg.singular_tolerance === nothing ? inv(sqrt(eps(eltype(u)))) : + eltype(u)(alg.singular_tolerance) + + return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve, + J, zero(J), zero(J), zero(fu), zero(fu), 0, singular_tolerance, false, + maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0), + init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip))) +end + +function perform_step!(cache::GeneralKlementCache{true}) + @unpack u, fu, f, p, alg, J, linsolve, du = cache + T = eltype(J) + + # FIXME: How can we do this faster? + if cond(J) > cache.singular_tolerance + if cache.resets == alg.max_resets + cache.force_stop = true + cache.retcode = ReturnCode.Unstable + return nothing + end + fill!(J, zero(T)) + J[diagind(J)] .= T(1) + cache.resets += 1 + end + + # u = u - J \ fu + linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu), linu = _vec(du), + p, reltol = cache.abstol) + cache.linsolve = linres.cache + + # Line Search + α = perform_linesearch!(cache.lscache, u, du) + axpy!(α, du, u) + f(cache.fu2, u, p) + + cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + cache.stats.nf += 1 + cache.stats.nsolve += 1 + cache.stats.nfactors += 1 + + cache.force_stop && return nothing + + # Update the Jacobian + cache.J_cache .= cache.J' .^ 2 + cache.Jdu .= _vec(du) .^ 2 + mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu) + mul!(cache.Jdu, J, _vec(du)) + cache.fu .= cache.fu2 .- cache.fu + cache.fu .= (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T)) + mul!(cache.J_cache, _vec(cache.fu), _vec(du)') + cache.J_cache .*= J + mul!(cache.J_cache2, cache.J_cache, J) + J .+= cache.J_cache2 + + cache.fu .= cache.fu2 + + return nothing +end + +function perform_step!(cache::GeneralKlementCache{false}) + @unpack fu, f, p, alg, J, linsolve = cache + T = eltype(J) + + # FIXME: How can we do this faster? 
+ if cond(J) > cache.singular_tolerance + if cache.resets == alg.max_resets + cache.force_stop = true + cache.retcode = ReturnCode.Unstable + return nothing + end + J = cache.J = __init_identity_jacobian(cache.u, fu) + cache.resets += 1 + end + + # u = u - J \ fu + if linsolve === nothing + cache.du = -fu / cache.J + else + linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu), + linu = _vec(cache.du), p, reltol = cache.abstol) + cache.linsolve = linres.cache + end + + # Line Search + α = perform_linesearch!(cache.lscache, cache.u, cache.du) + cache.u = @. cache.u + α * cache.du # `u` might not support mutation + cache.fu2 = f(cache.u, p) + + cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + cache.stats.nf += 1 + cache.stats.nsolve += 1 + cache.stats.nfactors += 1 + + cache.force_stop && return nothing + + # Update the Jacobian + cache.J_cache = cache.J' .^ 2 + cache.Jdu = _vec(cache.du) .^ 2 + cache.Jᵀ²du = cache.J_cache * cache.Jdu + cache.Jdu = J * _vec(cache.du) + cache.fu = cache.fu2 .- cache.fu + cache.fu = (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T)) + cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J + cache.J = J .+ cache.J_cache + + cache.fu = cache.fu2 + + return nothing +end + +function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p, + abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + cache.p = p + if iip + recursivecopy!(cache.u, u0) + cache.f(cache.fu, cache.u, p) + else + # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter + cache.u = u0 + cache.fu = cache.f(cache.u, p) + end + cache.abstol = abstol + cache.maxiters = maxiters + cache.stats.nf = 1 + cache.stats.nsteps = 1 + cache.force_stop = false + cache.retcode = ReturnCode.Default + return cache +end diff --git a/test/23_test_problems.jl b/test/23_test_problems.jl index 8c739c085..39ae155a3 100644 --- a/test/23_test_problems.jl +++
b/test/23_test_problems.jl @@ -92,3 +92,14 @@ end test_on_library(problems, dicts, alg_ops, broken_tests) end + +@testset "GeneralKlement 23 Test Problems" begin + alg_ops = (GeneralKlement(), + GeneralKlement(; linesearch = BackTracking())) + + broken_tests = Dict(alg => Int[] for alg in alg_ops) + broken_tests[alg_ops[1]] = [1, 2, 3, 4, 5, 6, 7, 13, 22] + broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 12, 22] + + test_on_library(problems, dicts, alg_ops, broken_tests) +end diff --git a/test/basictests.jl b/test/basictests.jl index 22ba9ab25..a39d25656 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -754,3 +754,92 @@ end @test nlprob_iterator_interface(quadratic_f, p, Val(false)) ≈ sqrt.(p) @test nlprob_iterator_interface(quadratic_f!, p, Val(true)) ≈ sqrt.(p) end + +# --- GeneralKlement tests --- + +@testset "GeneralKlement" begin + function benchmark_nlsolve_oop(f, u0, p = 2.0; linesearch = LineSearch()) + prob = NonlinearProblem{false}(f, u0, p) + return solve(prob, GeneralKlement(; linesearch), abstol = 1e-9) + end + + function benchmark_nlsolve_iip(f, u0, p = 2.0; linesearch = LineSearch()) + prob = NonlinearProblem{true}(f, u0, p) + return solve(prob, GeneralKlement(; linesearch), abstol = 1e-9) + end + + @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (Static(), + StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()), + ad in (AutoFiniteDiff(), AutoZygote()) + + linesearch = LineSearch(; method = lsmethod, autodiff = ad) + u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) + + @testset "[OOP] u0: $(typeof(u0))" for u0 in u0s + sol = benchmark_nlsolve_oop(quadratic_f, u0; linesearch) + @test SciMLBase.successful_retcode(sol) + @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + + cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), + GeneralKlement(; linesearch), abstol = 1e-9) + @test (@ballocated solve!($cache)) < 200 + end + + @testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],) + ad isa 
AutoZygote && continue + sol = benchmark_nlsolve_iip(quadratic_f!, u0; linesearch) + @test SciMLBase.successful_retcode(sol) + @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) + + cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0), + GeneralKlement(; linesearch), abstol = 1e-9) + @test (@ballocated solve!($cache)) ≤ 64 + end + end + + @testset "[OOP] [Immutable AD]" begin + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) + res_true = sqrt(p) + all(res.u .≈ res_true) + end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) + end + end + + @testset "[OOP] [Scalar AD]" begin + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, 1.0, p) + res_true = sqrt(p) + res.u ≈ res_true + end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, + p) ≈ 1 / (2 * sqrt(p)) + end + end + + t = (p) -> [sqrt(p[2] / p[1])] + p = [0.9, 50.0] + @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1]) + @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], + p) ≈ ForwardDiff.jacobian(t, p) + + # Iterator interface + function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip} + probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin]) + cache = init(probN, GeneralKlement(); maxiters = 100, abstol = 1e-10) + sols = zeros(length(p_range)) + for (i, p) in enumerate(p_range) + reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p) + sol = solve!(cache) + sols[i] = iip ? sol.u[1] : sol.u + end + return sols + end + p = range(0.01, 2, length = 200) + @test nlprob_iterator_interface(quadratic_f, p, Val(false)) ≈ sqrt.(p) + @test nlprob_iterator_interface(quadratic_f!, p, Val(true)) ≈ sqrt.(p) +end