Skip to content

Commit

Permalink
Make it a ext
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Nov 1, 2024
1 parent 637c3e8 commit fa47b24
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 24 deletions.
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
Expand All @@ -49,6 +48,7 @@ PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"

[extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
Expand All @@ -62,6 +62,7 @@ NonlinearSolvePETScExt = ["PETSc", "MPI"]
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
NonlinearSolveSpeedMappingExt = "SpeedMapping"
NonlinearSolveSundialsExt = "Sundials"
NonlinearSolveTaylorDiffExt = "TaylorDiff"

[compat]
ADTypes = "1.9"
Expand Down Expand Up @@ -121,7 +122,6 @@ StaticArrays = "1.9"
StaticArraysCore = "1.4"
Sundials = "4.23.1"
SymbolicIndexingInterface = "0.3.31"
Symbolics = "6"
TaylorDiff = "0.3"
Test = "1.10"
TimerOutputs = "0.5.23"
Expand Down Expand Up @@ -157,8 +157,9 @@ SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "TaylorDiff", "Test", "Zygote"]
19 changes: 19 additions & 0 deletions ext/NonlinearSolveTaylorDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module NonlinearSolveTaylorDiffExt
using NonlinearSolve: HalleyDescentCache, NonlinearFunction
import NonlinearSolve: evaluate_hvvp
using TaylorDiff: derivative, derivative!
using FastClosures: @closure

function evaluate_hvvp(
hvvp, cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
if iip
binary_f = @closure (y, x) -> f(y, x, p)
derivative!(hvvp, binary_f, cache.fu, u, δu, Val(2))
else
unary_f = Base.Fix2(f, p)
hvvp = derivative(unary_f, u, δu, Val(2))
end
hvvp
end

end
9 changes: 4 additions & 5 deletions src/algorithms/halley.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = NoLineSearch(),
Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = nothing,
precs = DEFAULT_PRECS, autodiff = nothing)
An experimental Halley's method implementation. Improves the convergence rate of Newton's method by using second-order derivative information to correct the descent direction.
Expand All @@ -8,8 +8,7 @@ Currently depends on TaylorDiff.jl to handle the correction terms,
might have more general implementation in the future.
"""
function Halley(; concrete_jac = nothing, linsolve = nothing,
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing)
descent = HalleyDescent(; linsolve, precs)
return GeneralizedFirstOrderAlgorithm(;
concrete_jac, name = :Halley, linesearch, descent, jacobian_ad = autodiff)
linesearch = nothing, precs = DEFAULT_PRECS, autodiff = nothing)
return GeneralizedFirstOrderAlgorithm{concrete_jac, :Halley}(;
linesearch, descent = HalleyDescent(; linsolve, precs), autodiff)
end
25 changes: 9 additions & 16 deletions src/descent/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ Improve the NewtonDescent with higher-order terms. First compute the descent dir
Then compute the hessian-vector-vector product and solve for the second-order correction term as ``J b = H a a``.
Finally, compute the descent direction as ``δu = a * a / (b / 2 - a)``.
Note that `import TaylorDiff` is required to use this descent algorithm.
See also [`NewtonDescent`](@ref).
"""
@kwdef @concrete struct HalleyDescent <: AbstractDescentAlgorithm
linsolve = nothing
precs = DEFAULT_PRECS
end

using TaylorDiff: derivative

function Base.show(io::IO, d::HalleyDescent)
modifiers = String[]
d.linsolve !== nothing && push!(modifiers, "linsolve = $(d.linsolve)")
Expand All @@ -30,6 +30,7 @@ supports_line_search(::HalleyDescent) = true
δus
b
fu
hvvp
lincache
timer
end
Expand All @@ -43,13 +44,14 @@ function __internal_init(prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; s
@bb δu = similar(u)
@bb b = similar(u)
@bb fu = similar(fu)
@bb hvvp = similar(fu)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
INV && return HalleyDescentCache{true}(prob.f, prob.p, δu, δus, b, nothing, timer)
lincache = LinearSolverCache(
lincache = INV ? nothing :
LinearSolverCache(
alg, alg.linsolve, J, _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...)
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, lincache, timer)
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, hvvp, lincache, timer)
end

function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1);
Expand All @@ -73,7 +75,7 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
end
b = cache.b
# compute the hessian-vector-vector product
hvvp = evaluate_hvvp(cache, cache.f, cache.p, u, δu)
hvvp = evaluate_hvvp(cache.hvvp, cache, cache.f, cache.p, u, δu)
# second linear solve, reuse factorization if possible
if INV
@bb b = J × vec(hvvp)
Expand All @@ -94,13 +96,4 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
return DescentResult(; δu)
end

function evaluate_hvvp(
cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
if iip
binary_f = @closure (y, x) -> f(y, x, p)
derivative(binary_f, cache.fu, u, δu, Val{3}())
else
unary_f = Base.Fix2(f, p)
derivative(unary_f, u, δu, Val{3}())
end
end
evaluate_hvvp(hvvp, cache, f, p, u, δu) = error("not implemented. please import TaylorDiff")
1 change: 1 addition & 0 deletions test/core/23_test_problems_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@testsetup module RobustnessTesting
using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test
import TaylorDiff

problems = NonlinearProblemLibrary.problems
dicts = NonlinearProblemLibrary.dicts
Expand Down

0 comments on commit fa47b24

Please sign in to comment.