Skip to content

Commit

Permalink
refactor: Move dual nonlinear solving to NonlinearSolveBase
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Dec 2, 2024
1 parent c340cdc commit ce721be
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 104 deletions.
111 changes: 109 additions & 2 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,36 @@ module NonlinearSolveBaseForwardDiffExt

using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using CommonSolve: solve
using CommonSolve: CommonSolve, solve
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
AbstractNonlinearSolveAlgorithm, Utils, InternalAPI,
AbstractNonlinearSolveCache

const DI = DifferentiationInterface

# Algorithm argument types for which the dual-aware `SciMLBase.__solve` and
# `SciMLBase.__init` methods in this extension are generated (one method per entry).
const ALL_SOLVER_TYPES = [Nothing, AbstractNonlinearSolveAlgorithm]

# Alias matching every `NonlinearProblem` whose first type parameter is a `Number` or an
# `AbstractArray` and whose third type parameter is a `ForwardDiff.Dual{T, V, P}` or an
# array of such duals. The in-place flag `iip` and the dual parameters are left free.
# NOTE(review): presumably the first/third parameters are the `u0` and `p` types, i.e.
# this selects problems whose parameters carry duals — confirm against SciMLBase.
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Alias matching every `NonlinearLeastSquaresProblem` whose first type parameter is a
# `Number` or an `AbstractArray` and whose third type parameter is a
# `ForwardDiff.Dual{T, V, P}` or an array of such duals; `iip`, `T`, `V`, `P` are free.
# NOTE(review): presumably the first/third parameters are the `u0` and `p` types —
# confirm against SciMLBase.
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Union of the two dual-problem aliases. It is the `prob` argument type of the
# `SciMLBase.__solve` and `SciMLBase.__init` methods generated in this extension.
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function NonlinearSolveBase.additional_incompatible_backend_check(
prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
return !ForwardDiff.can_dual(eltype(prob.u0))
Expand Down Expand Up @@ -102,4 +121,92 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
end

# Generate one dual-aware `SciMLBase.__solve` method per entry of `ALL_SOLVER_TYPES`.
for AlgT in ALL_SOLVER_TYPES
    @eval function SciMLBase.__solve(
            prob::DualAbstractNonlinearProblem, alg::$(AlgT), args...; kwargs...
    )
        # The helper returns the inner solution together with the partials that are
        # attached to its `u` below.
        inner_sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
            prob, alg, args...; kwargs...
        )
        dual_u = NonlinearSolveBase.nonlinearsolve_dual_solution(
            inner_sol.u, partials, prob.p
        )
        # Rebuild the solution against the original (dual) problem, forwarding the
        # residual and metadata of the inner solve unchanged.
        return SciMLBase.build_solution(
            prob, alg, dual_u, inner_sol.resid;
            retcode = inner_sol.retcode,
            stats = inner_sol.stats,
            original = inner_sol.original
        )
    end
end

# Cache wrapper used when a nonlinear problem is initialised with dual-valued parameters.
# It is constructed by the `SciMLBase.__init` methods in this file and consumed by
# `CommonSolve.solve!` / `InternalAPI.reinit!`.
@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
# Inner solver cache, built from the dual-stripped problem.
cache
# The dual-stripped problem (`u0` and `p` passed through `nodual_value`).
prob
# Algorithm the cache was initialised with.
alg
# Parameters as supplied by the caller, still carrying their duals.
p
# `nodual_value(p)`: the parameters with duals stripped.
values_p
# `ForwardDiff.partials(p)` evaluated on the caller-supplied parameters.
partials_p
end

# Re-initialise a `NonlinearSolveForwardDiffCache`.
#
# The wrapped cache is reset with dual-stripped `p` and `u0` (all other keyword
# arguments are forwarded untouched), after which the wrapper's own parameter fields
# are refreshed from the (possibly dual-valued) `p`. Returns `cache`.
function InternalAPI.reinit!(
        cache::NonlinearSolveForwardDiffCache, args...;
        p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
)
    inner_cache = cache.cache
    plain_p = nodual_value(p)
    plain_u0 = nodual_value(u0)
    InternalAPI.reinit!(inner_cache; p = plain_p, u0 = plain_u0, kwargs...)

    # Keep the dual-carrying parameters, a separately computed stripped copy, and
    # their partials on the wrapper.
    cache.p = p
    cache.values_p = nodual_value(p)
    cache.partials_p = ForwardDiff.partials(p)
    return cache
end

# Generate one dual-aware `SciMLBase.__init` method per entry of `ALL_SOLVER_TYPES`.
#
# Each method strips the duals from `prob.u0` and `prob.p`, initialises the regular
# solver cache on the stripped problem, and wraps it in a
# `NonlinearSolveForwardDiffCache` that remembers the original dual parameters.
for algType in ALL_SOLVER_TYPES
    @eval function SciMLBase.__init(
            prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
    )
        p = nodual_value(prob.p)
        newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
        # `init` is not among the names imported from CommonSolve in this module, so
        # the call must be module-qualified; a bare `init` raises `UndefVarError`.
        cache = CommonSolve.init(newprob, alg, args...; kwargs...)
        return NonlinearSolveForwardDiffCache(
            cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
        )
    end
end

# Solve the wrapped (dual-free) cache and re-attach ForwardDiff partials to the result.
#
# Steps:
#   1. run the inner solver on `cache.cache`;
#   2. evaluate the Jacobians of the residual function with respect to the parameters
#      (`Jₚ`) and the unknowns (`Jᵤ`) at the computed solution;
#   3. form the sensitivities `z_arr = -Jᵤ \ Jₚ`;
#   4. combine them with the partials stored in the dual parameters `cache.p`;
#   5. build a solution whose `u` carries those partials.
#
# Returns the solution produced by `SciMLBase.build_solution` for the dual-stripped
# problem stored in the cache.
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
    # `solve!` is not among the names imported from CommonSolve in this module, and
    # extending `CommonSolve.solve!` does not create a local binding, so the inner
    # call must be module-qualified; a bare `solve!` raises `UndefVarError`.
    sol = CommonSolve.solve!(cache.cache)
    prob = cache.prob
    uu = sol.u

    # Least-squares problems differentiate a generated VJP function instead of `prob.f`.
    fn = prob isa NonlinearLeastSquaresProblem ?
         NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

    Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
    Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)

    z_arr = -Jᵤ \ Jₚ

    # Scale every sensitivity entry by the partials carried by one parameter.
    sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
    partials = if cache.p isa Number
        # Scalar parameter: a single set of partials applies to all of `z_arr`.
        sumfun((z_arr, cache.p))
    else
        # Array parameter: pair each column of `z_arr` with the matching parameter
        # and accumulate the contributions.
        sum(sumfun, zip(eachcol(z_arr), cache.p))
    end

    dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(uu, partials, cache.p)
    return SciMLBase.build_solution(
        prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
    )
end

# Strip ForwardDiff duals from a value: a `Dual` is replaced by `ForwardDiff.value` of
# it, an array of `Dual`s is mapped element-wise, and anything else is returned as is.
function nodual_value(x)
    return x
end

function nodual_value(x::Dual)
    return ForwardDiff.value(x)
end

function nodual_value(x::AbstractArray{<:Dual})
    return map(ForwardDiff.value, x)
end

"""
    pickchunksize(x)
    pickchunksize(x::Int)

Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input
length. A non-`Int` argument is first reduced to `length(x)`; the `Int` method forwards
to `ForwardDiff.pickchunksize`.
"""
@inline function pickchunksize(x)
    return pickchunksize(length(x))
end

@inline function pickchunksize(x::Int)
    return ForwardDiff.pickchunksize(x)
end

end
3 changes: 2 additions & 1 deletion lib/NonlinearSolveFirstOrder/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
Expand All @@ -86,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
10 changes: 10 additions & 0 deletions lib/NonlinearSolveFirstOrder/test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,13 @@
@test sol.retcode == ReturnCode.Success
@test jac_calls == 0
end

@testitem "Dual of BigFloat: Issue #512" tags=[:core] begin
    using NonlinearSolveFirstOrder, ForwardDiff

    # In-place residual `u .* u .- p`, solved with `BigFloat`-valued duals in both the
    # initial guess and the parameter.
    residual! = (du, u, p) -> du .= u .* u .- p
    fn_iip = NonlinearFunction{true}(residual!)

    u2 = [ForwardDiff.Dual(BigFloat(1.0), 5.0) for _ in 1:3]
    p_dual = ForwardDiff.Dual(BigFloat(2.0), 5.0)

    prob_iip_bf = NonlinearProblem{true}(fn_iip, u2, p_dual)
    sol = solve(prob_iip_bf, NewtonRaphson())
    @test sol.retcode == ReturnCode.Success
end
2 changes: 0 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ const ALL_SOLVER_TYPES = [
NonlinearSolvePolyAlgorithm
]

include("forward_diff.jl")

@setup_workload begin
nonlinear_functions = (
(NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1),
Expand Down
99 changes: 0 additions & 99 deletions src/forward_diff.jl

This file was deleted.

0 comments on commit ce721be

Please sign in to comment.