Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Halley and Householder to SimpleNonlinearSolve #507

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolveDiffEqBaseExt = "DiffEqBase"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveTaylorDiffExt = "TaylorDiff"
SimpleNonlinearSolveTrackerExt = "Tracker"

[compat]
Expand Down Expand Up @@ -66,6 +68,7 @@ SciMLBase = "2.58"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.3"
TaylorDiff = "0.3"
Test = "1.10"
TestItemRunner = "1"
Tracker = "0.2.35"
Expand All @@ -84,10 +87,11 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]
test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "TaylorDiff", "Test", "TestItemRunner", "Tracker", "Zygote"]
74 changes: 74 additions & 0 deletions lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTaylorDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
module SimpleNonlinearSolveTaylorDiffExt
using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleHouseholder, Utils
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
AbstractNonlinearSolveAlgorithm
using MaybeInplace: @bb
using FastClosures: @closure
import SciMLBase
import TaylorDiff

SimpleNonlinearSolve.is_extension_loaded(::Val{:TaylorDiff}) = true

const NLBUtils = NonlinearSolveBase.Utils

# Evaluate `prob.f` on an order-`N` Taylor seed built at `x` (direction: `one` of every
# entry of `x`) and return `(num, den, fx)`, where
#   - `fx`  holds `TaylorDiff.value` of the evaluated bundle (written into the incoming
#           `fx` for in-place problems, freshly built otherwise),
#   - `num` is what TaylorDiff extracts at order `N - 1` from the elementwise reciprocal
#           of the bundle (its plain value when `N == 1`),
#   - `den` is what TaylorDiff extracts at order `N` from the same reciprocal.
# The caller combines these as `x .+= N .* num ./ den`.
@inline function __get_higher_order_derivatives(
        ::SimpleHouseholder{N}, prob, x, fx) where {N}
    order = Val(N)
    direction = map(one, x)
    seed = TaylorDiff.make_seed(x, direction, order)

    if SciMLBase.isinplace(prob)
        # In-place problems need an output buffer whose elements can carry Taylor data.
        taylor_fx = similar(fx, TaylorDiff.TaylorScalar{eltype(fx), N})
        prob.f(taylor_fx, seed, prob.p)
        map!(TaylorDiff.value, fx, taylor_fx)
    else
        taylor_fx = prob.f(seed, prob.p)
        fx = map(TaylorDiff.value, taylor_fx)
    end

    reciprocal = inv.(taylor_fx)
    if N == 1
        # Order zero is requested: take the value of the reciprocal itself.
        num = map(TaylorDiff.value, reciprocal)
    else
        num = TaylorDiff.extract_derivative(reciprocal, Val(N - 1))
    end
    den = TaylorDiff.extract_derivative(reciprocal, order)
    return num, den, fx
end

# Householder iteration of order `N` for problems with exactly one unknown.
#
# Keyword arguments:
#   - `abstol`, `reltol`, `termination_condition`: forwarded to
#     `NonlinearSolveBase.init_termination_cache`.
#   - `maxiters`: upper bound on the number of iterations (default 1000).
#   - `alias_u0`: forwarded to `NLBUtils.maybe_unaliased` for the initial guess.
#
# Returns the result of `SciMLBase.build_solution`; the retcode is `Success` when the
# initial residual is already zero, whatever `Utils.check_termination` reports when it
# signals convergence, and `MaxIters` when the loop is exhausted.
#
# Throws `ArgumentError` when `length(prob.u0) != 1`.
#
# `ReturnCode` is referenced as `SciMLBase.ReturnCode` because this module only does
# `import SciMLBase`; the bare name `ReturnCode` is not in scope here.
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHouseholder{N},
        args...; abstol = nothing, reltol = nothing, maxiters = 1000,
        termination_condition = nothing, alias_u0 = false, kwargs...) where {N}
    length(prob.u0) == 1 ||
        throw(ArgumentError("SimpleHouseholder only supports scalar problems"))
    x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
    fx = NLBUtils.evaluate_f(prob, x)

    # The initial guess is already a root: nothing to iterate on.
    iszero(fx) && return SciMLBase.build_solution(
        prob, alg, x, fx; retcode = SciMLBase.ReturnCode.Success)

    abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
        prob, abstol, reltol, fx, x, termination_condition, Val(:simple))

    @bb xo = similar(x)

    for _ in 1:maxiters
        # Keep the previous iterate; the termination check receives both.
        @bb copyto!(xo, x)
        num, den, fx = __get_higher_order_derivatives(alg, prob, x, fx)
        # Householder update built from the order N-1 and order N terms of 1/f.
        @bb x .+= N .* num ./ den
        solved, retcode, fx_sol, x_sol = Utils.check_termination(
            tc_cache, fx, x, xo, prob)
        solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
    end

    return SciMLBase.build_solution(
        prob, alg, x, fx; retcode = SciMLBase.ReturnCode.MaxIters)
end

# Second-order directional derivative of `prob.f` at `u` along `a`, computed with
# `TaylorDiff` at order `Val(2)` (the "Hessian-vector-vector product" used by the
# Taylor-mode branch of `SimpleHalley`).
#
# Arguments:
#   - `hvvp`: destination for in-place problems (overwritten); ignored as input for
#     out-of-place problems, where a new value is returned instead.
#   - `prob`: the nonlinear problem; `prob.f` and `prob.p` are used.
#   - `u`:    point of evaluation.
#   - `a`:    direction.
#
# Returns `hvvp` (the mutated buffer, or the freshly computed value).
function SimpleNonlinearSolve.evaluate_hvvp_internal(
        hvvp, prob::ImmutableNonlinearProblem, u, a)
    if SciMLBase.isinplace(prob)
        binary_f = @closure (y, x) -> prob.f(y, x, prob.p)
        # `derivative!` needs an output buffer for `f!`. The previous code read it from
        # `cache.fu`, but no `cache` exists in this scope (UndefVarError). A scratch
        # buffer shaped like `hvvp` is used instead.
        # NOTE(review): assumes `hvvp` has the shape and eltype of `prob.f`'s output,
        # as it receives the directional derivative of that output — confirm.
        fu = similar(hvvp)
        TaylorDiff.derivative!(hvvp, binary_f, fu, u, a, Val(2))
    else
        unary_f = Base.Fix2(prob.f, prob.p)
        hvvp = TaylorDiff.derivative(unary_f, u, a, Val(2))
    end
    return hvvp
end

end
10 changes: 9 additions & 1 deletion lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ include("utils.jl")
include("broyden.jl")
include("dfsane.jl")
include("halley.jl")
include("householder.jl")
include("klement.jl")
include("lbroyden.jl")
include("raphson.jl")
Expand Down Expand Up @@ -128,6 +129,13 @@ end

function solve_adjoint_internal end

# Entry point for the Taylor-mode Hessian-vector-vector product. All arguments are
# forwarded untouched to `evaluate_hvvp_internal`, whose methods are supplied by the
# TaylorDiff package extension; an error is raised when that extension is not loaded.
function evaluate_hvvp(args...; kws...)
    if !is_extension_loaded(Val(:TaylorDiff))
        error("Halley's method with Taylor mode requires `TaylorDiff.jl` to be explicitly loaded.")
    end
    return evaluate_hvvp_internal(args...; kws...)
end

# Method-less stub; the TaylorDiff extension adds the actual implementation.
function evaluate_hvvp_internal end

@setup_workload begin
for T in (Float64,)
prob_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
Expand Down Expand Up @@ -161,7 +169,7 @@ end
export SimpleBroyden, SimpleKlement, SimpleLimitedMemoryBroyden
export SimpleDFSane
export SimpleGaussNewton, SimpleNewtonRaphson, SimpleTrustRegion
export SimpleHalley
export SimpleHalley, SimpleHouseholder

export solve

Expand Down
30 changes: 24 additions & 6 deletions lib/SimpleNonlinearSolve/src/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ A low-overhead implementation of Halley's Method.
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
automatic backend selection). Valid choices include jacobian backends from
`DifferentiationInterface.jl`.
In addition, `AutoTaylorDiff` can be used to enable Taylor mode, which computes the Hessian-vector-vector product more efficiently; in this case, the Jacobian is still computed with `AutoForwardDiff`. `TaylorDiff.jl` must be loaded to use this option.
"""
@kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm
autodiff = nothing
Expand All @@ -38,6 +39,7 @@ function SciMLBase.__solve(

# The way we write the 2nd order derivatives, we know Enzyme won't work there
autodiff = alg.autodiff === nothing ? AutoForwardDiff() : alg.autodiff
jac_autodiff = autodiff === AutoTaylorDiff() ? AutoForwardDiff() : autodiff
@set! alg.autodiff = autodiff

@bb xo = copy(x)
Expand All @@ -50,8 +52,19 @@ function SciMLBase.__solve(
A, Aaᵢ, cᵢ = x, x, x
end

fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
NLBUtils.safe_similar(fx) : fx
jac_cache = Utils.prepare_jacobian(prob, jac_autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, jac_autodiff, fx_cache, x, jac_cache)

for _ in 1:maxiters
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
if autodiff isa AutoTaylorDiff
fx = NLBUtils.evaluate_f!!(prob, fx, x)
J = Utils.compute_jacobian!!(J, prob, jac_autodiff, fx_cache, x, jac_cache)
H = nothing
else
fx, J, H = Utils.compute_jacobian_and_hessian(jac_autodiff, prob, fx, x)
end

NLBUtils.can_setindex(x) || (A = J)

Expand All @@ -67,12 +80,17 @@ function SciMLBase.__solve(
end

aᵢ = J_fact \ NLBUtils.safe_vec(fx)
A_ = NLBUtils.safe_vec(A)
@bb A_ = H × aᵢ
A = NLBUtils.restructure(A, A_)

@bb Aaᵢ = A × aᵢ
@bb A .*= -1
if autodiff isa AutoTaylorDiff
Aaᵢ = evaluate_hvvp(Aaᵢ, prob, x, typeof(x)(aᵢ))
else
A_ = NLBUtils.safe_vec(A)
@bb A_ = H × aᵢ
A = NLBUtils.restructure(A, A_)

@bb Aaᵢ = A × aᵢ
@bb A .*= -1
end
bᵢ = J_fact \ NLBUtils.safe_vec(Aaᵢ)

cᵢ_ = NLBUtils.safe_vec(cᵢ)
Expand Down
16 changes: 16 additions & 0 deletions lib/SimpleNonlinearSolve/src/householder.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
    SimpleHouseholder{order}()

Low-overhead Householder iteration of arbitrary order for root finding.
The solver only accepts problems with a single unknown (`length(u0) == 1`) and is
non-allocating on scalar and static array problems.

!!! warning

    `TaylorDiff.jl` must be explicitly loaded before this algorithm can be used, since
    the required derivatives are computed with TaylorDiff.jl.

### Type Parameters

  - `order`: the order of the Householder method. `order = 1` corresponds to Newton's
    method, `order = 2` to Halley's method, and so on.
"""
struct SimpleHouseholder{order} <: AbstractSimpleNonlinearSolveAlgorithm end
32 changes: 30 additions & 2 deletions lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl
tansongchen marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testsnippet RootfindTestSnippet begin
using StaticArrays, Random, LinearAlgebra, ForwardDiff, NonlinearSolveBase, SciMLBase
using ADTypes, PolyesterForwardDiff, Enzyme, ReverseDiff
import TaylorDiff

quadratic_f(u, p) = u .* u .- p
quadratic_f!(du, u, p) = (du .= u .* u .- p)
Expand Down Expand Up @@ -81,11 +82,13 @@ end
AutoForwardDiff(),
AutoFiniteDiff(),
AutoReverseDiff(),
AutoTaylorDiff(),
nothing
)
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff))
sol = run_nlsolve_oop(
quadratic_f, u0; solver = alg(; autodiff))
@test SciMLBase.successful_retcode(sol)
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
end
Expand All @@ -96,7 +99,32 @@ end

probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(
probN, alg(; autodiff = AutoForwardDiff()); termination_condition).u .≈
probN, alg(; autodiff = AutoTaylorDiff());
termination_condition).u .≈
sqrt(2.0))
end
end
end

@testitem "Higher Order Methods" setup=[RootfindTestSnippet] tags=[:core] begin
@testset for alg in (
SimpleHouseholder,
)
@testset for order in (1, 2, 3, 4)
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
[1.0], @SVector[1.0], 1.0)
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg{order}())
@test SciMLBase.successful_retcode(sol)
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
end
end

@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0], @SVector[1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(
probN, alg{2}(); termination_condition).u .≈
sqrt(2.0))
end
end
Expand Down
Loading