From 00bded1c7736114260a984fee8ddb6975c997c4a Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Fri, 1 Nov 2024 14:30:47 -0400 Subject: [PATCH] Make it a ext --- Project.toml | 5 +- lib/NonlinearSolveBase/Project.toml | 2 + .../ext/NonlinearSolveBaseTaylorDiffExt.jl | 20 +++++ .../src/NonlinearSolveBase.jl | 1 + lib/NonlinearSolveBase/src/descent/halley.jl | 78 +++++++++---------- .../src/NonlinearSolveFirstOrder.jl | 5 +- lib/NonlinearSolveFirstOrder/src/halley.jl | 4 +- test/23_test_problems_tests.jl | 9 ++- 8 files changed, 74 insertions(+), 50 deletions(-) create mode 100644 lib/NonlinearSolveBase/ext/NonlinearSolveBaseTaylorDiffExt.jl diff --git a/Project.toml b/Project.toml index 76ccfd389..32cd0b054 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e" SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" @@ -113,7 +114,6 @@ StaticArrays = "1.9" StaticArraysCore = "1.4" Sundials = "4.23.1" SymbolicIndexingInterface = "0.3.31" -Symbolics = "6" TaylorDiff = "0.3" Test = "1.10" Zygote = "0.6.69" @@ -148,8 +148,9 @@ SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" +TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"] +test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "TaylorDiff", "Test", "Zygote"] diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 4abb81cde..bf603d9e9 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -35,6 +35,7 @@ LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c" [extensions] NonlinearSolveBaseBandedMatricesExt = "BandedMatrices" @@ -44,6 +45,7 @@ NonlinearSolveBaseLineSearchExt = "LineSearch" NonlinearSolveBaseLinearSolveExt = "LinearSolve" NonlinearSolveBaseSparseArraysExt = "SparseArrays" NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings" +NonlinearSolveBaseTaylorDiffExt = "TaylorDiff" [compat] ADTypes = "1.9" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTaylorDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTaylorDiffExt.jl new file mode 100644 index 000000000..1a4723527 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTaylorDiffExt.jl @@ -0,0 +1,20 @@ +module NonlinearSolveBaseTaylorDiffExt +using SciMLBase: NonlinearFunction +using NonlinearSolveBase: HalleyDescentCache +import NonlinearSolveBase: evaluate_hvvp +using TaylorDiff: derivative, derivative! +using FastClosures: @closure + +function evaluate_hvvp( + hvvp, cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip} + if iip + binary_f = @closure (y, x) -> f(y, x, p) + derivative!(hvvp, binary_f, cache.fu, u, δu, Val(2)) + else + unary_f = Base.Fix2(f, p) + hvvp = derivative(unary_f, u, δu, Val(2)) + end + hvvp +end + +end diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index a08384677..5d43216a2 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -50,6 +50,7 @@ include("wrappers.jl") include("descent/common.jl") include("descent/newton.jl") +include("descent/halley.jl") include("descent/steepest.jl") include("descent/damped_newton.jl") include("descent/dogleg.jl") diff --git a/lib/NonlinearSolveBase/src/descent/halley.jl b/lib/NonlinearSolveBase/src/descent/halley.jl index 596d78f19..de5f5eecc 100644 --- a/lib/NonlinearSolveBase/src/descent/halley.jl +++ b/lib/NonlinearSolveBase/src/descent/halley.jl @@ -1,70 +1,71 @@ """ - HalleyDescent(; linsolve = nothing, precs = DEFAULT_PRECS) + HalleyDescent(; linsolve = nothing) Improve the NewtonDescent with higher-order terms. First compute the descent direction as ``J a = -fu``. Then compute the hessian-vector-vector product and solve for the second-order correction term as ``J b = H a a``. Finally, compute the descent direction as ``δu = a * a / (b / 2 - a)``. +Note that `import TaylorDiff` is required to use this descent algorithm. + See also [`NewtonDescent`](@ref). """ -@kwdef @concrete struct HalleyDescent <: AbstractDescentAlgorithm +@kwdef @concrete struct HalleyDescent <: AbstractDescentDirection linsolve = nothing - precs = DEFAULT_PRECS -end - -using TaylorDiff: derivative - -function Base.show(io::IO, d::HalleyDescent) - modifiers = String[] - d.linsolve !== nothing && push!(modifiers, "linsolve = $(d.linsolve)") - d.precs !== DEFAULT_PRECS && push!(modifiers, "precs = $(d.precs)") - print(io, "HalleyDescent($(join(modifiers, ", ")))") end supports_line_search(::HalleyDescent) = true -@concrete mutable struct HalleyDescentCache{pre_inverted} <: AbstractDescentCache +@concrete mutable struct HalleyDescentCache <: AbstractDescentCache f p δu δus b fu + hvvp lincache timer + preinverted_jacobian <: Union{Val{false}, Val{true}} end @internal_caches HalleyDescentCache :lincache -function __internal_init(prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats, - shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, +function InternalAPI.init( + prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats, + shared = Val(1), pre_inverted::Val = Val(false), linsolve_kwargs = (;), abstol = nothing, reltol = nothing, - timer = get_timer_output(), kwargs...) where {INV, N} + timer = get_timer_output(), kwargs...) @bb δu = similar(u) @bb b = similar(u) @bb fu = similar(fu) - δus = N ≤ 1 ? nothing : map(2:N) do i + @bb hvvp = similar(fu) + δus = Utils.unwrap_val(shared) ≤ 1 ? nothing : map(2:Utils.unwrap_val(shared)) do i @bb δu_ = similar(u) end - INV && return HalleyDescentCache{true}(prob.f, prob.p, δu, δus, b, nothing, timer) - lincache = LinearSolverCache( - alg, alg.linsolve, J, _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...) - return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, lincache, timer) + lincache = Utils.unwrap_val(pre_inverted) ? nothing : + construct_linear_solver( + alg, alg.linsolve, J, Utils.safe_vec(fu), Utils.safe_vec(u); + stats, abstol, reltol, linsolve_kwargs... + ) + return HalleyDescentCache( + prob.f, prob.p, δu, δus, b, fu, hvvp, lincache, timer, pre_inverted) end -function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1); - skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) where {INV} - δu = get_du(cache, idx) +function InternalAPI.solve!( + cache::HalleyDescentCache, J, fu, u, idx::Val = Val(1); + skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) + δu = SciMLBase.get_du(cache, idx) skip_solve && return DescentResult(; δu) - if INV + if preinverted_jacobian(cache) @assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`." @bb δu = J × vec(fu) else @static_timeit cache.timer "linear solve 1" begin linres = cache.lincache(; - A = J, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu), + A = J, b = Utils.safe_vec(fu), + kwargs..., linu = Utils.safe_vec(δu), reuse_A_if_factorization = !new_jacobian || (idx !== Val(1))) - δu = _restructure(get_du(cache, idx), linres.u) + δu = Utils.restructure(SciMLBase.get_du(cache, idx), linres.u) if !linres.success set_du!(cache, δu, idx) return DescentResult(; δu, success = false, linsolve_success = false) @@ -73,15 +74,17 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = end b = cache.b # compute the hessian-vector-vector product - hvvp = evaluate_hvvp(cache, cache.f, cache.p, u, δu) + hvvp = evaluate_hvvp(cache.hvvp, cache, cache.f, cache.p, u, δu) # second linear solve, reuse factorization if possible - if INV + if preinverted_jacobian(cache) @bb b = J × vec(hvvp) else @static_timeit cache.timer "linear solve 2" begin - linres = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b), - du = _vec(b), reuse_A_if_factorization = true) - b = _restructure(cache.b, linres.u) + linres = cache.lincache(; + A = J, b = Utils.safe_vec(hvvp), + kwargs..., linu = Utils.safe_vec(b), + reuse_A_if_factorization = true) + b = Utils.restructure(cache.b, linres.u) if !linres.success set_du!(cache, δu, idx) return DescentResult(; δu, success = false, linsolve_success = false) @@ -94,13 +97,4 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = return DescentResult(; δu) end -function evaluate_hvvp( - cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip} - if iip - binary_f = @closure (y, x) -> f(y, x, p) - derivative(binary_f, cache.fu, u, δu, Val{3}()) - else - unary_f = Base.Fix2(f, p) - derivative(unary_f, u, δu, Val{3}()) - end -end +evaluate_hvvp(hvvp, cache, f, p, u, δu) = error("not implemented. please import TaylorDiff") diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 145468122..c07f03ae9 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -20,7 +20,7 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, AbstractTrustRegionMethodCache, Utils, InternalAPI, get_timer_output, @static_timeit, update_trace!, L2_NORM, - NewtonDescent, DampedNewtonDescent, GeodesicAcceleration, + NewtonDescent, DampedNewtonDescent, HalleyDescent, GeodesicAcceleration, Dogleg using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode, NonlinearFunction, @@ -31,6 +31,7 @@ using FiniteDiff: FiniteDiff # Default Finite Difference Method using ForwardDiff: ForwardDiff # Default Forward Mode AD include("raphson.jl") +include("halley.jl") include("gauss_newton.jl") include("levenberg_marquardt.jl") include("trust_region.jl") @@ -93,7 +94,7 @@ end @reexport using SciMLBase, NonlinearSolveBase -export NewtonRaphson, PseudoTransient +export NewtonRaphson, Halley, PseudoTransient export GaussNewton, LevenbergMarquardt, TrustRegion export RadiusUpdateSchemes diff --git a/lib/NonlinearSolveFirstOrder/src/halley.jl b/lib/NonlinearSolveFirstOrder/src/halley.jl index bf9714f7b..9f099d726 100644 --- a/lib/NonlinearSolveFirstOrder/src/halley.jl +++ b/lib/NonlinearSolveFirstOrder/src/halley.jl @@ -1,6 +1,6 @@ """ - Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = NoLineSearch(), - precs = DEFAULT_PRECS, autodiff = nothing) + Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = missing, + autodiff = nothing) An experimental Halley's method implementation. Improves the convergence rate of Newton's method by using second-order derivative information to correct the descent direction. diff --git a/test/23_test_problems_tests.jl b/test/23_test_problems_tests.jl index 8fa4c47b6..2b882d643 100644 --- a/test/23_test_problems_tests.jl +++ b/test/23_test_problems_tests.jl @@ -1,5 +1,6 @@ @testsetup module RobustnessTesting using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test +import TaylorDiff problems = NonlinearProblemLibrary.problems dicts = NonlinearProblemLibrary.dicts @@ -61,10 +62,14 @@ end end @testitem "23 Test Problems: Halley" setup=[RobustnessTesting] tags=[:core] begin - alg_ops = (SimpleHalley(; autodiff = AutoForwardDiff()),) + alg_ops = ( + Halley(), + SimpleHalley(; autodiff = AutoForwardDiff()) + ) broken_tests = Dict(alg => Int[] for alg in alg_ops) - broken_tests[alg_ops[1]] = [1, 5, 15, 16, 18] + broken_tests[alg_ops[1]] = [1, 5, 15, 16] + broken_tests[alg_ops[2]] = [1, 5, 15, 16, 18] test_on_library(problems, dicts, alg_ops, broken_tests) end