From 342b1ec095c0fe1a84b69e5183d81d0f25adb5e3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 28 Oct 2024 23:07:10 -0400 Subject: [PATCH] refactor(SimpleNonlinearSolve): reuse more code from NLB --- Project.toml | 1 + lib/BracketingNonlinearSolve/Project.toml | 2 +- .../NonlinearSolveBaseBandedMatricesExt.jl | 1 + .../ext/NonlinearSolveBaseForwardDiffExt.jl | 18 ++-- .../ext/NonlinearSolveBaseLineSearchExt.jl | 3 +- .../ext/NonlinearSolveBaseLinearSolveExt.jl | 10 +- .../ext/NonlinearSolveBaseSparseArraysExt.jl | 3 +- ...linearSolveBaseSparseMatrixColoringsExt.jl | 4 +- lib/NonlinearSolveBase/src/utils.jl | 4 +- lib/SimpleNonlinearSolve/Project.toml | 8 +- .../SimpleNonlinearSolveChainRulesCoreExt.jl | 13 ++- .../ext/SimpleNonlinearSolveReverseDiffExt.jl | 5 +- .../ext/SimpleNonlinearSolveTrackerExt.jl | 3 +- .../src/SimpleNonlinearSolve.jl | 92 +++++++++++-------- lib/SimpleNonlinearSolve/src/broyden.jl | 25 ++--- lib/SimpleNonlinearSolve/src/dfsane.jl | 45 +++++---- lib/SimpleNonlinearSolve/src/halley.jl | 41 +++++---- lib/SimpleNonlinearSolve/src/klement.jl | 16 ++-- lib/SimpleNonlinearSolve/src/lbroyden.jl | 86 +++++++++-------- lib/SimpleNonlinearSolve/src/raphson.jl | 18 ++-- lib/SimpleNonlinearSolve/src/trust_region.jl | 66 +++++++------ lib/SimpleNonlinearSolve/src/utils.jl | 90 +++++------------- src/NonlinearSolve.jl | 2 +- 23 files changed, 295 insertions(+), 261 deletions(-) diff --git a/Project.toml b/Project.toml index 6d907607e..768a46005 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "4.0.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index eb81fd7f7..f2a8e2b6d 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -24,7 +24,7 @@ ConcreteStructs = "0.2.3" ExplicitImports = "1.10.1" ForwardDiff = "0.10.36" InteractiveUtils = "<0.0.1, 1" -NonlinearSolveBase = "1" +NonlinearSolveBase = "1.1" PrecompileTools = "1.2" Reexport = "1.2" SciMLBase = "2.50" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseBandedMatricesExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseBandedMatricesExt.jl index 7f2ac7f90..93f01f51f 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseBandedMatricesExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseBandedMatricesExt.jl @@ -2,6 +2,7 @@ module NonlinearSolveBaseBandedMatricesExt using BandedMatrices: BandedMatrix using LinearAlgebra: Diagonal + using NonlinearSolveBase: NonlinearSolveBase, Utils # This is used if we vcat a Banded Jacobian with a Diagonal Matrix in Levenberg diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index c4f1dc901..0b16391c4 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -25,7 +25,8 @@ Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x) function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem}, - alg, args...; kwargs...) + alg, args...; kwargs... +) p = Utils.value(prob.p) if prob isa IntervalNonlinearProblem tspan = Utils.value.(prob.tspan) @@ -55,7 +56,8 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( end function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( - prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...) + prob::NonlinearLeastSquaresProblem, alg, args...; kwargs... +) p = Utils.value(prob.p) newprob = remake(prob; p, u0 = Utils.value(prob.u0)) sol = solve(newprob, alg, args...; kwargs...) @@ -168,13 +170,17 @@ function NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F} return ForwardDiff.jacobian(Base.Fix2(f, p), u) end -function NonlinearSolveBase.nonlinearsolve_dual_solution(u::Number, partials, - ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} +function NonlinearSolveBase.nonlinearsolve_dual_solution( + u::Number, partials, + ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} +) where {T, V, P} return Dual{T, V, P}(u, partials) end -function NonlinearSolveBase.nonlinearsolve_dual_solution(u::AbstractArray, partials, - ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} +function NonlinearSolveBase.nonlinearsolve_dual_solution( + u::AbstractArray, partials, + ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} +) where {T, V, P} return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials))) end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLineSearchExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLineSearchExt.jl index d68007dc0..3b705b0dc 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLineSearchExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLineSearchExt.jl @@ -1,9 +1,10 @@ module NonlinearSolveBaseLineSearchExt using LineSearch: LineSearch, AbstractLineSearchCache -using NonlinearSolveBase: NonlinearSolveBase, InternalAPI using SciMLBase: SciMLBase +using NonlinearSolveBase: NonlinearSolveBase, InternalAPI + function NonlinearSolveBase.callback_into_cache!( topcache, cache::AbstractLineSearchCache, args... ) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl index 13c4adca5..28b8b1937 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl @@ -1,15 +1,19 @@ module NonlinearSolveBaseLinearSolveExt using ArrayInterface: ArrayInterface + using CommonSolve: CommonSolve, init, solve! -using LinearAlgebra: ColumnNorm using LinearSolve: LinearSolve, QRFactorization, SciMLLinearSolveAlgorithm -using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils using SciMLBase: ReturnCode, LinearProblem +using LinearAlgebra: ColumnNorm + +using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils + function (cache::LinearSolveJLCache)(; A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing, - cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...) + cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs... +) cache.stats.nsolve += 1 update_A!(cache, A, reuse_A_if_factorization) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl index 09b113c4a..bc7350d21 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl @@ -1,8 +1,9 @@ module NonlinearSolveBaseSparseArraysExt -using NonlinearSolveBase: NonlinearSolveBase, Utils using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, nonzeros, sparse +using NonlinearSolveBase: NonlinearSolveBase, Utils + function NonlinearSolveBase.NAN_CHECK(x::AbstractSparseMatrixCSC) return any(NonlinearSolveBase.NAN_CHECK, nonzeros(x)) end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseMatrixColoringsExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseMatrixColoringsExt.jl index e2029d7a2..4daf5ea98 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseMatrixColoringsExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseMatrixColoringsExt.jl @@ -1,11 +1,13 @@ module NonlinearSolveBaseSparseMatrixColoringsExt using ADTypes: ADTypes, AbstractADType -using NonlinearSolveBase: NonlinearSolveBase, Utils using SciMLBase: SciMLBase, NonlinearFunction + using SparseMatrixColorings: ConstantColoringAlgorithm, GreedyColoringAlgorithm, LargestFirst +using NonlinearSolveBase: NonlinearSolveBase, Utils + Utils.is_extension_loaded(::Val{:SparseMatrixColorings}) = true function NonlinearSolveBase.select_fastest_coloring_algorithm( diff --git a/lib/NonlinearSolveBase/src/utils.jl b/lib/NonlinearSolveBase/src/utils.jl index 6e739c0f8..826c7f66c 100644 --- a/lib/NonlinearSolveBase/src/utils.jl +++ b/lib/NonlinearSolveBase/src/utils.jl @@ -138,7 +138,9 @@ maybe_unaliased(x::AbstractSciMLOperator, ::Bool) = x can_setindex(x) = ArrayInterface.can_setindex(x) can_setindex(::Number) = false -evaluate_f!!(prob::AbstractNonlinearProblem, fu, u, p) = evaluate_f!!(prob.f, fu, u, p) +function evaluate_f!!(prob::AbstractNonlinearProblem, fu, u, p = prob.p) + return evaluate_f!!(prob.f, fu, u, p) +end function evaluate_f!!(f::NonlinearFunction, fu, u, p) if SciMLBase.isinplace(f) f(fu, u, p) diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index cfda24544..c154a4b54 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -5,7 +5,6 @@ version = "2.0.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" @@ -21,6 +20,7 @@ NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] @@ -37,10 +37,9 @@ SimpleNonlinearSolveTrackerExt = "Tracker" [compat] ADTypes = "1.2" -Accessors = "0.1" Aqua = "0.8.7" ArrayInterface = "7.16" -BracketingNonlinearSolve = "1" +BracketingNonlinearSolve = "1.1" ChainRulesCore = "1.24" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" @@ -56,7 +55,7 @@ LineSearch = "0.1.3" LinearAlgebra = "1.10" MaybeInplace = "0.1.4" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1" +NonlinearSolveBase = "1.1" Pkg = "1.10" PolyesterForwardDiff = "0.1" PrecompileTools = "1.2" @@ -64,6 +63,7 @@ Random = "1.10" Reexport = "1.2" ReverseDiff = "1.15" SciMLBase = "2.50" +Setfield = "1.1.1" StaticArrays = "1.9" StaticArraysCore = "1.4.3" Test = "1.10" diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl index f56dee537..50905279a 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl @@ -1,21 +1,26 @@ module SimpleNonlinearSolveChainRulesCoreExt using ChainRulesCore: ChainRulesCore, NoTangent + using NonlinearSolveBase: ImmutableNonlinearProblem using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up, solve_adjoint -function ChainRulesCore.rrule(::typeof(simplenonlinearsolve_solve_up), +function ChainRulesCore.rrule( + ::typeof(simplenonlinearsolve_solve_up), prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, - sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) + sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs... +) out, ∇internal = solve_adjoint( - prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...) + prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs... + ) function ∇simplenonlinearsolve_solve_up(Δ) ∂f, ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Δ) return ( - ∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...) + ∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args... + ) end return out, ∇simplenonlinearsolve_solve_up end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl index 0a407986e..d34d8bac7 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -1,10 +1,11 @@ module SimpleNonlinearSolveReverseDiffExt -using ArrayInterface: ArrayInterface using NonlinearSolveBase: ImmutableNonlinearProblem -using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake +using ArrayInterface: ArrayInterface +using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal + using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint import SimpleNonlinearSolve: simplenonlinearsolve_solve_up diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl index d29c2ac61..d56854316 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -1,8 +1,9 @@ module SimpleNonlinearSolveTrackerExt -using ArrayInterface: ArrayInterface using NonlinearSolveBase: ImmutableNonlinearProblem using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake + +using ArrayInterface: ArrayInterface using Tracker: Tracker, TrackedArray, TrackedReal using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index f51064000..528838568 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -1,33 +1,46 @@ module SimpleNonlinearSolve -using Accessors: @reset -using BracketingNonlinearSolve: BracketingNonlinearSolve -using CommonSolve: CommonSolve, solve, init, solve! using ConcreteStructs: @concrete using FastClosures: @closure -using LineSearch: LiFukushimaLineSearch -using LinearAlgebra: LinearAlgebra, dot -using MaybeInplace: @bb, setindex_trait, CannotSetindex, CanSetindex using PrecompileTools: @compile_workload, @setup_workload using Reexport: @reexport -using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, NonlinearFunction, NonlinearProblem, - NonlinearLeastSquaresProblem, IntervalNonlinearProblem, ReturnCode, remake +using Setfield: @set! + +using BracketingNonlinearSolve: BracketingNonlinearSolve +using CommonSolve: CommonSolve, solve, init, solve! +using LineSearch: LiFukushimaLineSearch +using MaybeInplace: @bb +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM, + nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution, + AbstractNonlinearSolveAlgorithm +using SciMLBase: SciMLBase, NonlinearFunction, NonlinearProblem, + NonlinearLeastSquaresProblem, ReturnCode, remake + +using LinearAlgebra: LinearAlgebra, dot + using StaticArraysCore: StaticArray, SArray, SVector, MArray # AD Dependencies using ADTypes: ADTypes, AutoForwardDiff using DifferentiationInterface: DifferentiationInterface using FiniteDiff: FiniteDiff -using ForwardDiff: ForwardDiff - -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM, - nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution +using ForwardDiff: ForwardDiff, Dual const DI = DifferentiationInterface -abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} + +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} -const safe_similar = NonlinearSolveBase.Utils.safe_similar +abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearSolveAlgorithm end + +const NLBUtils = NonlinearSolveBase.Utils is_extension_loaded(::Val) = false @@ -42,61 +55,66 @@ include("raphson.jl") include("trust_region.jl") # By Pass the highlevel checks for NonlinearProblem for Simple Algorithms -function CommonSolve.solve(prob::NonlinearProblem, - alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) +function CommonSolve.solve( + prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; + kwargs... +) prob = convert(ImmutableNonlinearProblem, prob) return solve(prob, alg, args...; kwargs...) end function CommonSolve.solve( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{ - <:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}}, - alg::AbstractSimpleNonlinearSolveAlgorithm, - args...; - kwargs...) where {T, V, P, iip} + prob::DualNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; kwargs... +) if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing - @reset alg.autodiff = AutoForwardDiff() + @set! alg.autodiff = AutoForwardDiff() end prob = convert(ImmutableNonlinearProblem, prob) sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...) dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p) return SciMLBase.build_solution( - prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) end function CommonSolve.solve( - prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{ - <:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}}, - alg::AbstractSimpleNonlinearSolveAlgorithm, - args...; - kwargs...) where {T, V, P, iip} + prob::DualNonlinearLeastSquaresProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; kwargs... +) if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing - @reset alg.autodiff = AutoForwardDiff() + @set! alg.autodiff = AutoForwardDiff() end sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...) dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p) return SciMLBase.build_solution( - prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) end function CommonSolve.solve( prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, alg::AbstractSimpleNonlinearSolveAlgorithm, - args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) + args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs... +) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] end new_u0 = u0 !== nothing ? u0 : prob.u0 new_p = p !== nothing ? p : prob.p - return simplenonlinearsolve_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, - p === nothing, alg, args...; prob.kwargs..., kwargs...) + return simplenonlinearsolve_solve_up( + prob, sensealg, + new_u0, u0 === nothing, + new_p, p === nothing, + alg, args...; + prob.kwargs..., kwargs... + ) end function simplenonlinearsolve_solve_up( prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, - u0_changed, p, p_changed, alg, args...; kwargs...) + u0_changed, p, p_changed, alg, args...; kwargs... +) (u0_changed || p_changed) && (prob = remake(prob; u0, p)) return SciMLBase.__solve(prob, alg, args...; kwargs...) end @@ -131,7 +149,7 @@ function solve_adjoint_internal end @compile_workload begin for prob in (prob_scalar, prob_iip, prob_oop), alg in algs - CommonSolve.solve(prob, alg; abstol = 1e-2) + CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false) end end end diff --git a/lib/SimpleNonlinearSolve/src/broyden.jl b/lib/SimpleNonlinearSolve/src/broyden.jl index 6537a4d2d..48a056b7d 100644 --- a/lib/SimpleNonlinearSolve/src/broyden.jl +++ b/lib/SimpleNonlinearSolve/src/broyden.jl @@ -18,17 +18,19 @@ array problems. end function SimpleBroyden(; - linesearch::Union{Bool, Val{true}, Val{false}} = Val(false), alpha = nothing) + linesearch::Union{Bool, Val{true}, Val{false}} = Val(false), alpha = nothing +) linesearch = linesearch isa Bool ? Val(linesearch) : linesearch return SimpleBroyden(linesearch, alpha) end -function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...; +function SciMLBase.__solve( + prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - alias_u0 = false, termination_condition = nothing, kwargs...) - x = Utils.maybe_unaliased(prob.u0, alias_u0) - fx = Utils.get_fx(prob, x) - fx = Utils.eval_f(prob, fx, x) + alias_u0 = false, termination_condition = nothing, kwargs... +) + x = NLBUtils.maybe_unaliased(prob.u0, alias_u0) + fx = NLBUtils.evaluate_f(prob, x) T = promote_type(eltype(fx), eltype(x)) iszero(fx) && @@ -54,9 +56,10 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, @bb δJ⁻¹ = copy(J⁻¹) abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache( - prob, abstol, reltol, fx, x, termination_condition, Val(:simple)) + prob, abstol, reltol, fx, x, termination_condition, Val(:simple) + ) - if alg.linesearch === Val(true) + if alg.linesearch isa Val{true} ls_alg = LiFukushimaLineSearch(; nan_maxiters = nothing) ls_cache = init(prob, ls_alg, fx, x) else @@ -75,7 +78,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, end @bb @. x = xo + α * δx - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f!!(prob, fx, x) @bb @. δf = fx - fprev # Termination Checks @@ -88,8 +91,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, @bb @. δJ⁻¹n = (δx - J⁻¹δf) / d - δJ⁻¹n_ = Utils.safe_vec(δJ⁻¹n) - xᵀJ⁻¹_ = Utils.safe_vec(xᵀJ⁻¹) + δJ⁻¹n_ = NLBUtils.safe_vec(δJ⁻¹n) + xᵀJ⁻¹_ = NLBUtils.safe_vec(xᵀJ⁻¹) @bb δJ⁻¹ = δJ⁻¹n_ × transpose(xᵀJ⁻¹_) @bb J⁻¹ .+= δJ⁻¹ diff --git a/lib/SimpleNonlinearSolve/src/dfsane.jl b/lib/SimpleNonlinearSolve/src/dfsane.jl index 0d400b0ce..fb371e3c3 100644 --- a/lib/SimpleNonlinearSolve/src/dfsane.jl +++ b/lib/SimpleNonlinearSolve/src/dfsane.jl @@ -1,7 +1,9 @@ """ - SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, + SimpleDFSane(; + σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, M::Union{Int, Val} = Val(10), γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, - nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2) + nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2 + ) A low-overhead implementation of the df-sane method for solving large-scale nonlinear systems of equations. For in depth information about all the parameters and the algorithm, @@ -48,20 +50,26 @@ see [la2006spectral](@citet). M <: Val end -# XXX[breaking]: we should change the names to not have unicode -function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, - M::Union{Int, Val} = Val(10), γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, - nexp::Int = 2, η_strategy::F = (f_1, k, x, F) -> f_1 ./ k^2) where {F} +function SimpleDFSane(; + sigma_min::Real = 1e-10, sigma_max::Real = 1e10, sigma_1::Real = 1.0, + M::Union{Int, Val} = Val(10), gamma::Real = 1e-4, tau_min::Real = 0.1, + tau_max::Real = 0.5, n_exp::Int = 2, + eta_strategy::F = (fn_1, n, x_n, f_n) -> fn_1 / n^2 +) where {F} M = M isa Int ? Val(M) : M - return SimpleDFSane(σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy, M) + return SimpleDFSane( + sigma_min, sigma_max, sigma_1, gamma, tau_min, tau_max, n_exp, + eta_strategy, M + ) end -function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane, args...; +function SciMLBase.__solve( + prob::ImmutableNonlinearProblem, alg::SimpleDFSane, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, - termination_condition = nothing, kwargs...) - x = Utils.maybe_unaliased(prob.u0, alias_u0) - fx = Utils.get_fx(prob, x) - fx = Utils.eval_f(prob, fx, x) + termination_condition = nothing, kwargs... +) + x = NLBUtils.maybe_unaliased(prob.u0, alias_u0) + fx = NLBUtils.evaluate_f(prob, x) T = promote_type(eltype(fx), eltype(x)) σ_min = T(alg.σ_min) @@ -74,7 +82,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane, a τ_max = T(alg.τ_max) abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache( - prob, abstol, reltol, fx, x, termination_condition, Val(:simple)) + prob, abstol, reltol, fx, x, termination_condition, Val(:simple) + ) fx_norm = L2_NORM(fx)^nexp α_1 = one(T) @@ -104,7 +113,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane, a @bb @. x_cache = x + α_p * d - fx = Utils.eval_f(prob, fx, x_cache) + fx = NLBUtils.evaluate_f!!(prob, fx, x_cache) fx_norm_new = L2_NORM(fx)^nexp while k < maxiters @@ -113,7 +122,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane, a α_tp = α_p^2 * fx_norm / (fx_norm_new + (T(2) * α_p - T(1)) * fx_norm) @bb @. x_cache = x - α_m * d - fx = Utils.eval_f(prob, fx, x_cache) + fx = NLBUtils.evaluate_f!!(prob, fx, x_cache) fx_norm_new = L2_NORM(fx)^nexp (fx_norm_new ≤ (f_bar + η - γ * α_m^2 * fx_norm)) && break @@ -123,7 +132,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane, a α_m = clamp(α_tm, τ_min * α_m, τ_max * α_m) @bb @. x_cache = x + α_p * d - fx = Utils.eval_f(prob, fx, x_cache) + fx = NLBUtils.evaluate_f!!(prob, fx, x_cache) fx_norm_new = L2_NORM(fx)^nexp k += 1 @@ -146,11 +155,11 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane, a fx_norm = fx_norm_new # Store function value - idx = mod1(k, SciMLBase._unwrap_val(alg.M)) + idx = mod1(k, NLBUtils.unwrap_val(alg.M)) if history_f_k isa SVector history_f_k = Base.setindex(history_f_k, fx_norm_new, idx) elseif history_f_k isa NTuple - @reset history_f_k[idx] = fx_norm_new + @set! history_f_k[idx] = fx_norm_new else history_f_k[idx] = fx_norm_new end diff --git a/lib/SimpleNonlinearSolve/src/halley.jl b/lib/SimpleNonlinearSolve/src/halley.jl index 30eb1a821..2d8446d90 100644 --- a/lib/SimpleNonlinearSolve/src/halley.jl +++ b/lib/SimpleNonlinearSolve/src/halley.jl @@ -23,33 +23,37 @@ end function SciMLBase.__solve( prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - alias_u0 = false, termination_condition = nothing, kwargs...) - x = Utils.maybe_unaliased(prob.u0, alias_u0) - fx = Utils.get_fx(prob, x) - fx = Utils.eval_f(prob, fx, x) + alias_u0 = false, termination_condition = nothing, kwargs... +) + x = NLBUtils.maybe_unaliased(prob.u0, alias_u0) + fx = NLBUtils.evaluate_f(prob, x) T = promote_type(eltype(fx), eltype(x)) iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache( - prob, abstol, reltol, fx, x, termination_condition, Val(:simple)) + prob, abstol, reltol, fx, x, termination_condition, Val(:simple) + ) # The way we write the 2nd order derivatives, we know Enzyme won't work there autodiff = alg.autodiff === nothing ? AutoForwardDiff() : alg.autodiff + @set! alg.autodiff = autodiff @bb xo = copy(x) - strait = setindex_trait(x) - - A = strait isa CanSetindex ? safe_similar(x, length(x), length(x)) : x - Aaᵢ = strait isa CanSetindex ? safe_similar(x, length(x)) : x - cᵢ = strait isa CanSetindex ? safe_similar(x) : x + if NLBUtils.can_setindex(x) + A = NLBUtils.safe_similar(x, length(x), length(x)) + Aaᵢ = NLBUtils.safe_similar(x, length(x)) + cᵢ = NLBUtils.safe_similar(x) + else + A, Aaᵢ, cᵢ = x, x, x + end for _ in 1:maxiters fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x) - strait isa CannotSetindex && (A = J) + NLBUtils.can_setindex(x) || (A = J) # Factorize Once and Reuse J_fact = if J isa Number @@ -57,22 +61,23 @@ function SciMLBase.__solve( else fact = LinearAlgebra.lu(J; check = false) !LinearAlgebra.issuccess(fact) && return SciMLBase.build_solution( - prob, alg, x, fx; retcode = ReturnCode.Unstable) + prob, alg, x, fx; retcode = ReturnCode.Unstable + ) fact end - aᵢ = J_fact \ Utils.safe_vec(fx) - A_ = Utils.safe_vec(A) + aᵢ = J_fact \ NLBUtils.safe_vec(fx) + A_ = NLBUtils.safe_vec(A) @bb A_ = H × aᵢ - A = Utils.restructure(A, A_) + A = NLBUtils.restructure(A, A_) @bb Aaᵢ = A × aᵢ @bb A .*= -1 - bᵢ = J_fact \ Utils.safe_vec(Aaᵢ) + bᵢ = J_fact \ NLBUtils.safe_vec(Aaᵢ) - cᵢ_ = Utils.safe_vec(cᵢ) + cᵢ_ = NLBUtils.safe_vec(cᵢ) @bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ)) - cᵢ = Utils.restructure(cᵢ, cᵢ_) + cᵢ = NLBUtils.restructure(cᵢ, cᵢ_) solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob) solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode) diff --git a/lib/SimpleNonlinearSolve/src/klement.jl b/lib/SimpleNonlinearSolve/src/klement.jl index 31c4cca96..a8fb7705c 100644 --- a/lib/SimpleNonlinearSolve/src/klement.jl +++ b/lib/SimpleNonlinearSolve/src/klement.jl @@ -6,16 +6,18 @@ method is non-allocating on scalar and static array problems. """ struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end -function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...; +function SciMLBase.__solve( + prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - alias_u0 = false, termination_condition = nothing, kwargs...) - x = Utils.maybe_unaliased(prob.u0, alias_u0) + alias_u0 = false, termination_condition = nothing, kwargs... +) + x = NLBUtils.maybe_unaliased(prob.u0, alias_u0) T = eltype(x) - fx = Utils.get_fx(prob, x) - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f(prob, x) abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache( - prob, abstol, reltol, fx, x, termination_condition, Val(:simple)) + prob, abstol, reltol, fx, x, termination_condition, Val(:simple) + ) @bb δx = copy(x) @bb fprev = copy(fx) @@ -31,7 +33,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, @bb @. δx = fprev / J @bb @. x = xo - δx - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f!!(prob, fx, x) # Termination Checks solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob) diff --git a/lib/SimpleNonlinearSolve/src/lbroyden.jl b/lib/SimpleNonlinearSolve/src/lbroyden.jl index d2bd6ef83..a0d33f942 100644 --- a/lib/SimpleNonlinearSolve/src/lbroyden.jl +++ b/lib/SimpleNonlinearSolve/src/lbroyden.jl @@ -1,6 +1,7 @@ """ - SimpleLimitedMemoryBroyden(; threshold::Union{Val, Int} = Val(27), - linesearch = Val(false), alpha = nothing) + SimpleLimitedMemoryBroyden(; + threshold::Union{Val, Int} = Val(27), linesearch = Val(false), alpha = nothing + ) A limited memory implementation of Broyden. This method applies the L-BFGS scheme to Broyden's method. @@ -40,7 +41,8 @@ function SciMLBase.__solve( if termination_condition === nothing || termination_condition isa NonlinearSolveBase.AbsNormTerminationMode return internal_static_solve( - prob, alg, args...; termination_condition, kwargs...) + prob, alg, args...; termination_condition, kwargs... + ) end @warn "Specifying `termination_condition = $(termination_condition)` for \ `SimpleLimitedMemoryBroyden` with `SArray` is not non-allocating. Use \ @@ -53,22 +55,26 @@ end @views function internal_generic_solve( prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - alias_u0 = false, termination_condition = nothing, kwargs...) - x = Utils.maybe_unaliased(prob.u0, alias_u0) - η = min(SciMLBase._unwrap_val(alg.threshold), maxiters) + alias_u0 = false, termination_condition = nothing, kwargs... +) + x = NLBUtils.maybe_unaliased(prob.u0, alias_u0) + η = min(NLBUtils.unwrap_val(alg.threshold), maxiters) # For scalar problems / if the threshold is larger than problem size just use Broyden if x isa Number || length(x) ≤ η - return SciMLBase.__solve(prob, SimpleBroyden(; alg.linesearch), args...; abstol, - reltol, maxiters, termination_condition, kwargs...) + return SciMLBase.__solve( + prob, SimpleBroyden(; alg.linesearch), args...; + abstol, reltol, maxiters, termination_condition, kwargs... + ) end - fx = Utils.get_fx(prob, x) + fx = NLBUtils.evaluate_f(prob, x) U, Vᵀ = init_low_rank_jacobian(x, fx, x isa StaticArray ? alg.threshold : Val(η)) abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache( - prob, abstol, reltol, fx, x, termination_condition, Val(:simple)) + prob, abstol, reltol, fx, x, termination_condition, Val(:simple) + ) @bb xo = copy(x) @bb δx = copy(fx) @@ -80,7 +86,7 @@ end Tcache = lbroyden_threshold_cache(x, x isa StaticArray ? alg.threshold : Val(η)) @bb mat_cache = copy(x) - if alg.linesearch === Val(true) + if alg.linesearch isa Val{true} ls_alg = LiFukushimaLineSearch(; nan_maxiters = nothing) ls_cache = init(prob, ls_alg, fx, x) else @@ -96,7 +102,7 @@ end end @bb @. x = xo + α * δx - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f!!(prob, fx, x) @bb @. δf = fx - fo # Termination Checks @@ -111,8 +117,8 @@ end d = dot(vᵀ, δf) @bb @. δx = (δx - mvec) / d - selectdim(U, 2, mod1(i, η)) .= Utils.safe_vec(δx) - selectdim(Vᵀ, 1, mod1(i, η)) .= Utils.safe_vec(vᵀ) + selectdim(U, 2, mod1(i, η)) .= NLBUtils.safe_vec(δx) + selectdim(Vᵀ, 1, mod1(i, η)) .= NLBUtils.safe_vec(vᵀ) Uₚ = selectdim(U, 2, 1:min(η, i)) Vᵀₚ = selectdim(Vᵀ, 1, 1:min(η, i)) @@ -130,10 +136,11 @@ end # finicky, so we'll implement it separately from the generic version # Ignore termination_condition. Don't pass things into internal functions function internal_static_solve( - prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden, - args...; abstol = nothing, maxiters = 1000, kwargs...) + prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden, args...; + abstol = nothing, maxiters = 1000, kwargs... +) x = prob.u0 - fx = Utils.get_fx(prob, x) + fx = NLBUtils.evaluate_f(prob, x) U, Vᵀ = init_low_rank_jacobian(vec(x), vec(fx), alg.threshold) @@ -165,7 +172,7 @@ function internal_static_solve( xo, fo, δx = res.x, res.fx, res.δx - for i in 1:(maxiters - SciMLBase._unwrap_val(alg.threshold)) + for i in 1:(maxiters - NLBUtils.unwrap_val(alg.threshold)) if ls_cache === nothing α = true else @@ -174,22 +181,22 @@ function internal_static_solve( end x = xo + α * δx - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f!!(prob, fx, x) δf = fx - fo maximum(abs, fx) ≤ abstol && return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) - vᵀ = Utils.restructure(x, rmatvec!!(U, Vᵀ, vec(δx), init_α)) - mvec = Utils.restructure(x, matvec!!(U, Vᵀ, vec(δf), init_α)) + vᵀ = NLBUtils.restructure(x, rmatvec!!(U, Vᵀ, vec(δx), init_α)) + mvec = NLBUtils.restructure(x, matvec!!(U, Vᵀ, vec(δf), init_α)) d = dot(vᵀ, δf) δx = @. (δx - mvec) / d - U = Base.setindex(U, vec(δx), mod1(i, SciMLBase._unwrap_val(alg.threshold))) - Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), mod1(i, SciMLBase._unwrap_val(alg.threshold))) + U = Base.setindex(U, vec(δx), mod1(i, NLBUtils.unwrap_val(alg.threshold))) + Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), mod1(i, NLBUtils.unwrap_val(alg.threshold))) - δx = -Utils.restructure(fx, matvec!!(U, Vᵀ, vec(fx), init_α)) + δx = -NLBUtils.restructure(fx, matvec!!(U, Vᵀ, vec(fx), init_α)) xo, fo = x, fx end @@ -198,8 +205,8 @@ function internal_static_solve( end @generated function internal_unrolled_lbroyden_initial_iterations( - prob, xo, fo, δx, abstol, U, Vᵀ, ::Val{threshold}, - ls_cache, init_α) where {threshold} + prob, xo, fo, δx, abstol, U, Vᵀ, ::Val{threshold}, ls_cache, init_α +) where {threshold} calls = [] for i in 1:threshold static_idx, static_idx_p1 = Val(i - 1), Val(i) @@ -219,8 +226,8 @@ end Uₚ = first_n_getindex(U, $(static_idx)) Vᵀₚ = first_n_getindex(Vᵀ, $(static_idx)) - vᵀ = Utils.restructure(x, rmatvec!!(Uₚ, Vᵀₚ, vec(δx), init_α)) - mvec = Utils.restructure(x, matvec!!(Uₚ, Vᵀₚ, vec(δf), init_α)) + vᵀ = NLBUtils.restructure(x, rmatvec!!(Uₚ, Vᵀₚ, vec(δx), init_α)) + mvec = NLBUtils.restructure(x, matvec!!(Uₚ, Vᵀₚ, vec(δf), init_α)) d = dot(vᵀ, δf) δx = @. (δx - mvec) / d @@ -230,7 +237,7 @@ end Uₚ = first_n_getindex(U, $(static_idx_p1)) Vᵀₚ = first_n_getindex(Vᵀ, $(static_idx_p1)) - δx = -Utils.restructure(fx, matvec!!(Uₚ, Vᵀₚ, vec(fx), init_α)) + δx = -NLBUtils.restructure(fx, matvec!!(Uₚ, Vᵀₚ, vec(fx), init_α)) x0, fo = x, fx end) @@ -284,7 +291,8 @@ function fast_mapdot(x::SVector{S1}, Y::SVector{S2, <:SVector{S1}}) where {S1, S return map(Base.Fix1(dot, x), Y) end @generated function fast_mapTdot( - x::SVector{S1}, Y::SVector{S1, <:SVector{S2}}) where {S1, S2} + x::SVector{S1}, Y::SVector{S1, <:SVector{S2}} +) where {S1, S2} calls = [] syms = [gensym("m$(i)") for i in 1:S1] for i in 1:S1 @@ -301,22 +309,26 @@ end return :(return SVector{$N, $T}(($(getcalls...)))) end -lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = safe_similar(x, threshold) +function lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} + return NLBUtils.safe_similar(x, threshold) +end function lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold} return zeros(MArray{Tuple{threshold}, eltype(x)}) end lbroyden_threshold_cache(::SArray, ::Val{threshold}) where {threshold} = nothing -function init_low_rank_jacobian(u::StaticArray{S1, T1}, fu::StaticArray{S2, T2}, - ::Val{threshold}) where {S1, S2, T1, T2, threshold} +function init_low_rank_jacobian( + u::StaticArray{S1, T1}, fu::StaticArray{S2, T2}, ::Val{threshold} +) where {S1, S2, T1, T2, threshold} T = promote_type(T1, T2) fuSize, uSize = Size(fu), Size(u) Vᵀ = MArray{Tuple{threshold, prod(uSize)}, T}(undef) U = MArray{Tuple{prod(fuSize), threshold}, T}(undef) return U, Vᵀ end -@generated function init_low_rank_jacobian(u::SVector{Lu, T1}, fu::SVector{Lfu, T2}, - ::Val{threshold}) where {Lu, Lfu, T1, T2, threshold} +@generated function init_low_rank_jacobian( + u::SVector{Lu, T1}, fu::SVector{Lfu, T2}, ::Val{threshold} +) where {Lu, Lfu, T1, T2, threshold} T = promote_type(T1, T2) inner_inits_Vᵀ = [:(zeros(SVector{$Lu, $T})) for i in 1:threshold] inner_inits_U = [:(zeros(SVector{$Lfu, $T})) for i in 1:threshold] @@ -327,7 +339,7 @@ end end end function init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold} - Vᵀ = safe_similar(u, threshold, length(u)) - U = safe_similar(u, length(fu), threshold) + Vᵀ = NLBUtils.safe_similar(u, threshold, length(u)) + U = NLBUtils.safe_similar(u, length(fu), threshold) return U, Vᵀ end diff --git a/lib/SimpleNonlinearSolve/src/raphson.jl b/lib/SimpleNonlinearSolve/src/raphson.jl index ebbb5f9f9..34efcbb90 100644 --- a/lib/SimpleNonlinearSolve/src/raphson.jl +++ b/lib/SimpleNonlinearSolve/src/raphson.jl @@ -27,35 +27,37 @@ function SciMLBase.__solve( prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing, maxiters = 1000, - alias_u0 = false, termination_condition = nothing, kwargs...) - x = Utils.maybe_unaliased(prob.u0, alias_u0) - fx = Utils.get_fx(prob, x) - fx = Utils.eval_f(prob, fx, x) + alias_u0 = false, termination_condition = nothing, kwargs... +) + x = NLBUtils.maybe_unaliased(prob.u0, alias_u0) + fx = NLBUtils.evaluate_f(prob, x) iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache( - prob, abstol, reltol, fx, x, termination_condition, Val(:simple)) + prob, abstol, reltol, fx, x, termination_condition, Val(:simple) + ) autodiff = SciMLBase.has_jac(prob.f) ? alg.autodiff : NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff) + @set! alg.autodiff = autodiff @bb xo = similar(x) fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? - safe_similar(fx) : fx + NLBUtils.safe_similar(fx) : fx jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x) J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache) for _ in 1:maxiters @bb copyto!(xo, x) - δx = Utils.restructure(x, J \ Utils.safe_vec(fx)) + δx = NLBUtils.restructure(x, J \ NLBUtils.safe_vec(fx)) @bb x .-= δx solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob) solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode) - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f!!(prob, fx, x) J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache) end diff --git a/lib/SimpleNonlinearSolve/src/trust_region.jl b/lib/SimpleNonlinearSolve/src/trust_region.jl index 32e7a6219..d9ac54235 100644 --- a/lib/SimpleNonlinearSolve/src/trust_region.jl +++ b/lib/SimpleNonlinearSolve/src/trust_region.jl @@ -1,9 +1,11 @@ """ - SimpleTrustRegion(; autodiff = AutoForwardDiff(), max_trust_radius = 0.0, + SimpleTrustRegion(; + autodiff = AutoForwardDiff(), max_trust_radius = 0.0, initial_trust_radius = 0.0, step_threshold = nothing, shrink_threshold = nothing, expand_threshold = nothing, shrink_factor = 0.25, expand_factor = 2.0, max_shrink_times::Int = 32, - nlsolve_update_rule = Val(false)) + nlsolve_update_rule = Val(false) + ) A low-overhead implementation of a trust-region solver. This method is non-allocating on scalar and static array problems. @@ -18,18 +20,18 @@ scalar and static array problems. - `initial_trust_radius`: the initial trust region radius. Defaults to `max_trust_radius / 11`. - `step_threshold`: the threshold for taking a step. In every iteration, the threshold is - compared with a value `r`, which is the actual reduction in the objective function divided - by the predicted reduction. If `step_threshold > r` the model is not a good approximation, - and the step is rejected. Defaults to `0.1`. For more details, see + compared with a value `r`, which is the actual reduction in the objective function + divided by the predicted reduction. If `step_threshold > r` the model is not a good + approximation, and the step is rejected. Defaults to `0.1`. For more details, see [Rahpeymaii, F.](https://link.springer.com/article/10.1007/s40096-020-00339-4) - `shrink_threshold`: the threshold for shrinking the trust region radius. In every - iteration, the threshold is compared with a value `r` which is the actual reduction in the - objective function divided by the predicted reduction. If `shrink_threshold > r` the trust - region radius is shrunk by `shrink_factor`. Defaults to `0.25`. For more details, see - [Rahpeymaii, F.](https://link.springer.com/article/10.1007/s40096-020-00339-4) + iteration, the threshold is compared with a value `r` which is the actual reduction in + the objective function divided by the predicted reduction. If `shrink_threshold > r` the + trust region radius is shrunk by `shrink_factor`. Defaults to `0.25`. For more details, + see [Rahpeymaii, F.](https://link.springer.com/article/10.1007/s40096-020-00339-4) - `expand_threshold`: the threshold for expanding the trust region radius. If a step is - taken, i.e `step_threshold < r` (with `r` defined in `shrink_threshold`), a check is also - made to see if `expand_threshold < r`. If that is true, the trust region radius is + taken, i.e `step_threshold < r` (with `r` defined in `shrink_threshold`), a check is + also made to see if `expand_threshold < r`. If that is true, the trust region radius is expanded by `expand_factor`. Defaults to `0.75`. - `shrink_factor`: the factor to shrink the trust region radius with if `shrink_threshold > r` (with `r` defined in `shrink_threshold`). Defaults to `0.25`. @@ -55,29 +57,32 @@ scalar and static array problems. nlsolve_update_rule = Val(false) end -function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegion, - args...; abstol = nothing, reltol = nothing, maxiters = 1000, - alias_u0 = false, termination_condition = nothing, kwargs...) - x = Utils.maybe_unaliased(prob.u0, alias_u0) +function SciMLBase.__solve( + prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, + alg::SimpleTrustRegion, args...; + abstol = nothing, reltol = nothing, maxiters = 1000, + alias_u0 = false, termination_condition = nothing, kwargs... +) + x = NLBUtils.maybe_unaliased(prob.u0, alias_u0) T = eltype(x) Δₘₐₓ = T(alg.max_trust_radius) Δ = T(alg.initial_trust_radius) η₁ = T(alg.step_threshold) if alg.shrink_threshold === nothing - η₂ = T(ifelse(SciMLBase._unwrap_val(alg.nlsolve_update_rule), 0.05, 0.25)) + η₂ = T(ifelse(NLBUtils.unwrap_val(alg.nlsolve_update_rule), 0.05, 0.25)) else η₂ = T(alg.shrink_threshold) end if alg.expand_threshold === nothing - η₃ = T(ifelse(SciMLBase._unwrap_val(alg.nlsolve_update_rule), 0.9, 0.75)) + η₃ = T(ifelse(NLBUtils.unwrap_val(alg.nlsolve_update_rule), 0.9, 0.75)) else η₃ = T(alg.expand_threshold) end if alg.shrink_factor === nothing - t₁ = T(ifelse(SciMLBase._unwrap_val(alg.nlsolve_update_rule), 0.5, 0.25)) + t₁ = T(ifelse(NLBUtils.unwrap_val(alg.nlsolve_update_rule), 0.5, 0.25)) else t₁ = T(alg.shrink_factor) end @@ -88,23 +93,23 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi autodiff = SciMLBase.has_jac(prob.f) ? alg.autodiff : NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff) - fx = Utils.get_fx(prob, x) - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f(prob, x) norm_fx = L2_NORM(fx) @bb xo = copy(x) fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? - safe_similar(fx) : fx + NLBUtils.safe_similar(fx) : fx jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x) J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache) abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache( - prob, abstol, reltol, fx, x, termination_condition, Val(:simple)) + prob, abstol, reltol, fx, x, termination_condition, Val(:simple) + ) # Set default trust region radius if not specified by user. iszero(Δₘₐₓ) && (Δₘₐₓ = max(L2_NORM(fx), maximum(x) - minimum(x))) if iszero(Δ) - if SciMLBase._unwrap_val(alg.nlsolve_update_rule) + if NLBUtils.unwrap_val(alg.nlsolve_update_rule) norm_x = L2_NORM(x) Δ = T(ifelse(norm_x > 0, norm_x, 1)) else @@ -114,7 +119,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi fₖ = 0.5 * norm_fx^2 H = transpose(J) * J - g = Utils.restructure(x, J' * Utils.safe_vec(fx)) + g = NLBUtils.restructure(x, J' * NLBUtils.safe_vec(fx)) shrink_counter = 0 @bb δsd = copy(x) @@ -128,7 +133,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi δ = dogleg_method!!(dogleg_cache, J, fx, g, Δ) @bb @. x = xo + δ - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f!!(prob, fx, x) fₖ₊₁ = L2_NORM(fx)^2 / T(2) @@ -149,17 +154,18 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi if r ≥ η₁ # Termination Checks solved, retcode, fx_sol, x_sol = Utils.check_termination( - tc_cache, fx, x, xo, prob) + tc_cache, fx, x, xo, prob + ) solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode) # Take the step. @bb copyto!(xo, x) J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache) - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f!!(prob, fx, x) # Update the trust region radius. - if !SciMLBase._unwrap_val(alg.nlsolve_update_rule) && r > η₃ + if !NLBUtils.unwrap_val(alg.nlsolve_update_rule) && r > η₃ Δ = min(t₂ * Δ, Δₘₐₓ) end fₖ = fₖ₊₁ @@ -168,7 +174,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi @bb g = transpose(J) × vec(fx) end - if SciMLBase._unwrap_val(alg.nlsolve_update_rule) + if NLBUtils.unwrap_val(alg.nlsolve_update_rule) if r > η₃ Δ = t₂ * L2_NORM(δ) elseif r > 0.5 @@ -184,7 +190,7 @@ function dogleg_method!!(cache, J, f, g, Δ) (; δsd, δN_δsd, δN) = cache # Compute the Newton step - @bb δN .= Utils.restructure(δN, J \ Utils.safe_vec(f)) + @bb δN .= NLBUtils.restructure(δN, J \ NLBUtils.safe_vec(f)) @bb δN .*= -1 # Test if the full step is within the trust region (L2_NORM(δN) ≤ Δ) && return δN diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index bd7368bd7..8c35a324f 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -1,70 +1,31 @@ module Utils using ArrayInterface: ArrayInterface -using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface, Constant using FastClosures: @closure using LinearAlgebra: LinearAlgebra, I, diagind -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, - AbstractNonlinearTerminationMode, +using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearTerminationMode, AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode -using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearLeastSquaresProblem, - NonlinearProblem, NonlinearFunction, ReturnCode +using SciMLBase: SciMLBase, ReturnCode using StaticArraysCore: StaticArray, SArray, SMatrix, SVector const DI = DifferentiationInterface - -const safe_similar = NonlinearSolveBase.Utils.safe_similar - -pickchunksize(n::Int) = min(n, 12) - -can_dual(::Type{<:Real}) = true -can_dual(::Type) = false - -maybe_unaliased(x::Union{Number, SArray}, ::Bool) = x -function maybe_unaliased(x::T, alias::Bool) where {T <: AbstractArray} - (alias || !ArrayInterface.can_setindex(T)) && return x - return copy(x) -end - -# NOTE: This doesn't initialize the `f(x)` but just returns a buffer of the same size -function get_fx(prob::NonlinearLeastSquaresProblem, x) - if SciMLBase.isinplace(prob) && prob.f.resid_prototype === nothing - error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype` to be \ - specified.") - end - return get_fx(prob.f, x, prob.p) -end -function get_fx(prob::Union{ImmutableNonlinearProblem, NonlinearProblem}, x) - return get_fx(prob.f, x, prob.p) -end -function get_fx(f::NonlinearFunction, x, p) - if SciMLBase.isinplace(f) - f.resid_prototype === nothing || return eltype(x).(f.resid_prototype) - return safe_similar(x) - end - return f(x, p) -end - -function eval_f(prob, fx, x) - SciMLBase.isinplace(prob) || return prob.f(x, prob.p) - prob.f(fx, x, prob.p) - return fx -end - -function fixed_parameter_function(prob::AbstractNonlinearProblem) - SciMLBase.isinplace(prob) && return @closure (du, u) -> prob.f(du, u, prob.p) - return Base.Fix2(prob.f, prob.p) -end +const NLBUtils = NonlinearSolveBase.Utils function identity_jacobian(u::Number, fu::Number, α = true) return convert(promote_type(eltype(u), eltype(fu)), α) end function identity_jacobian(u, fu, α = true) - J = safe_similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u)) - fill!(J, zero(eltype(J))) - J[diagind(J)] .= eltype(J)(α) + J = NLBUtils.safe_similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u)) + fill!(J, false) + if ArrayInterface.fast_scalar_indexing(J) + @simd ivdep for i in axes(J, 1) + @inbounds J[i, i] = α + end + else + J[diagind(J)] .= α + end return J end function identity_jacobian(u::StaticArray, fu, α = true) @@ -97,30 +58,21 @@ function check_termination(cache, fx, x, xo, _, ::AbstractSafeNonlinearTerminati return cache(fx, x, xo), cache.retcode, fx, x end function check_termination( - cache, fx, x, xo, prob, ::AbstractSafeBestNonlinearTerminationMode) + cache, fx, x, xo, prob, ::AbstractSafeBestNonlinearTerminationMode +) if cache(fx, x, xo) x = cache.u - if SciMLBase.isinplace(prob) - prob.f(fx, x, prob.p) - else - fx = prob.f(x, prob.p) - end + fx = NLBUtils.evaluate_f!!(prob, fx, x) return true, cache.retcode, fx, x end return false, ReturnCode.Default, fx, x end -restructure(y, x) = ArrayInterface.restructure(y, x) -restructure(::Number, x::Number) = x - -safe_vec(x::AbstractArray) = vec(x) -safe_vec(x::Number) = x - abstract type AbstractJacobianMode end struct AnalyticJacobian <: AbstractJacobianMode end -@concrete struct DIExtras <: AbstractJacobianMode - prep +struct DIExtras{P} <: AbstractJacobianMode + prep::P end struct DINoPreparation <: AbstractJacobianMode end @@ -161,7 +113,7 @@ end function compute_jacobian!!(J, prob, autodiff, fx, x, ::AnalyticJacobian) if J === nothing if SciMLBase.isinplace(prob.f) - J = safe_similar(fx, length(fx), length(x)) + J = NLBUtils.safe_similar(fx, length(fx), length(x)) prob.f.jac(J, x, prob.p) return J else @@ -214,16 +166,16 @@ end function compute_jacobian_and_hessian(autodiff, prob, fx, x) if SciMLBase.isinplace(prob) jac_fn = @closure (u, p) -> begin - du = safe_similar(fx, promote_type(eltype(fx), eltype(u))) + du = NLBUtils.safe_similar(fx, promote_type(eltype(fx), eltype(u))) return DI.jacobian(prob.f, du, autodiff, u, Constant(p)) end J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p)) - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f!!(prob, fx, x) return fx, J, H else jac_fn = @closure (u, p) -> DI.jacobian(prob.f, autodiff, u, Constant(p)) J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p)) - fx = Utils.eval_f(prob, fx, x) + fx = NLBUtils.evaluate_f!!(prob, fx, x) return fx, J, H end end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index e352f4891..fc6d3f722 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -71,7 +71,7 @@ end # Rexexports @reexport using SciMLBase, NonlinearSolveBase, LineSearch, ADTypes @reexport using NonlinearSolveFirstOrder, NonlinearSolveSpectralMethods, - NonlinearSolveQuasiNewton, SimpleNonlinearSolve + NonlinearSolveQuasiNewton, SimpleNonlinearSolve, BracketingNonlinearSolve @reexport using LinearSolve # Poly Algorithms