Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Move dual nonlinear solving to NonlinearSolveBase #513

Merged
merged 10 commits into from
Dec 11, 2024
4 changes: 2 additions & 2 deletions docs/src/basics/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,10 @@ nothing # hide
```

And boom! Type stable again. We always recommend picking the chunksize via
[`NonlinearSolve.pickchunksize`](@ref), however, if you manually specify the chunksize, it
[`NonlinearSolveBase.pickchunksize`](@ref), however, if you manually specify the chunksize, it
must be `≤ length of input`. However, a very large chunksize can lead to excessive
compilation times and slowdown.

```@docs
NonlinearSolve.pickchunksize
NonlinearSolveBase.pickchunksize
```
98 changes: 95 additions & 3 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,35 @@ module NonlinearSolveBaseForwardDiffExt

using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using CommonSolve: solve
using CommonSolve: CommonSolve, solve, solve!, init
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using ForwardDiff: ForwardDiff, Dual, pickchunksize
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI,
NonlinearSolvePolyAlgorithm, NonlinearSolveForwardDiffCache

const DI = DifferentiationInterface

const GENERAL_SOLVER_TYPES = [
Nothing, NonlinearSolvePolyAlgorithm
]

const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function NonlinearSolveBase.additional_incompatible_backend_check(
prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
return !ForwardDiff.can_dual(eltype(prob.u0))
Expand Down Expand Up @@ -102,4 +120,78 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
end

for algType in GENERAL_SOLVER_TYPES
@eval function SciMLBase.__solve(
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob, alg, args...; kwargs...
)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end
end

function InternalAPI.reinit!(
cache::NonlinearSolveForwardDiffCache, args...;
p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
)
InternalAPI.reinit!(
cache.cache; p = NonlinearSolveBase.nodual_value(p),
u0 = NonlinearSolveBase.nodual_value(u0), kwargs...
)
cache.p = p
cache.values_p = NonlinearSolveBase.nodual_value(p)
cache.partials_p = ForwardDiff.partials(p)
return cache
end

for algType in GENERAL_SOLVER_TYPES
@eval function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
p = NonlinearSolveBase.nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end
end

function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
sol = solve!(cache.cache)
prob = cache.prob
uu = sol.u

fn = prob isa NonlinearLeastSquaresProblem ?
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)

z_arr = -Jᵤ \ Jₚ

sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
if cache.p isa Number
partials = sumfun((z_arr, cache.p))
else
partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
end

dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p)
return SciMLBase.build_solution(
prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end

NonlinearSolveBase.nodual_value(x) = x
NonlinearSolveBase.nodual_value(x::Dual) = ForwardDiff.value(x)
NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

@inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x))
@inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)

end
4 changes: 4 additions & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ include("descent/geodesic_acceleration.jl")

include("solve.jl")

include("forward_diff.jl")

# Unexported Public API
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
Expand All @@ -83,4 +85,6 @@ export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogle

export NonlinearSolvePolyAlgorithm

export pickchunksize

end
8 changes: 8 additions & 0 deletions lib/NonlinearSolveBase/src/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
cache
prob
alg
p
values_p
partials_p
end
9 changes: 9 additions & 0 deletions lib/NonlinearSolveBase/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ function nonlinearsolve_dual_solution end
function nonlinearsolve_∂f_∂p end
function nonlinearsolve_∂f_∂u end
function nlls_generate_vjp_function end
function nodual_value end

"""
pickchunksize(x) = pickchunksize(length(x))
pickchunksize(x::Int)

Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.
"""
function pickchunksize end

# Nonlinear Solve Termination Conditions
abstract type AbstractNonlinearTerminationMode end
Expand Down
3 changes: 2 additions & 1 deletion lib/NonlinearSolveFirstOrder/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
Expand All @@ -86,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
6 changes: 4 additions & 2 deletions lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
Utils, InternalAPI, get_timer_output, @static_timeit,
update_trace!, L2_NORM, NonlinearSolvePolyAlgorithm,
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
Dogleg
Dogleg, NonlinearSolveForwardDiffCache
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode,
NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize
using SciMLJacobianOperators: VecJacOperator, JacVecOperator, StatefulJacobianOperator

using FiniteDiff: FiniteDiff # Default Finite Difference Method
using ForwardDiff: ForwardDiff # Default Forward Mode AD
using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD

include("raphson.jl")
include("gauss_newton.jl")
Expand All @@ -41,6 +41,8 @@ include("poly_algs.jl")

include("solve.jl")

include("forward_diff.jl")

@setup_workload begin
nonlinear_functions = (
(NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1),
Expand Down
34 changes: 34 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs...
)
p = NonlinearSolveBase.nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end

function SciMLBase.__solve(
prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs...
)
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob, alg, args...; kwargs...
)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end
10 changes: 10 additions & 0 deletions lib/NonlinearSolveFirstOrder/test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,13 @@
@test sol.retcode == ReturnCode.Success
@test jac_calls == 0
end

@testitem "Dual of BigFloat: Issue #512" tags=[:core] begin
using NonlinearSolveFirstOrder, ForwardDiff
fn_iip = NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p)
u2 = [ForwardDiff.Dual(BigFloat(1.0), 5.0), ForwardDiff.Dual(BigFloat(1.0), 5.0),
ForwardDiff.Dual(BigFloat(1.0), 5.0)]
prob_iip_bf = NonlinearProblem{true}(fn_iip, u2, ForwardDiff.Dual(BigFloat(2.0), 5.0))
sol = solve(prob_iip_bf, NewtonRaphson())
@test sol.retcode == ReturnCode.Success
end
6 changes: 6 additions & 0 deletions lib/NonlinearSolveQuasiNewton/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
NonlinearSolveQuasiNewtonForwardDiffExt = "ForwardDiff"

[compat]
ADTypes = "1.9.0"
Aqua = "0.8"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module NonlinearSolveQuasiNewtonForwardDiffExt

using CommonSolve: CommonSolve, init
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem

using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value

using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm

const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function SciMLBase.__solve(
prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs...
)
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob, alg, args...; kwargs...
)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end

function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs...
)
p = nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end

end
7 changes: 7 additions & 0 deletions lib/NonlinearSolveSpectralMethods/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
NonlinearSolveSpectralMethodsForwardDiffExt = "ForwardDiff"

[compat]
Aqua = "0.8"
BenchmarkTools = "1.5.0"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.158.3"
ExplicitImports = "1.5"
ForwardDiff = "0.10.36"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LineSearch = "0.1.4"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module NonlinearSolveSpectralMethodsForwardDiffExt

using CommonSolve: CommonSolve, init
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem

using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value

using NonlinearSolveSpectralMethods: GeneralizedDFSane

const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function SciMLBase.__solve(
prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs...
)
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob, alg, args...; kwargs...
)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end

function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs...
)
p = nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end

end
Loading
Loading