From ce721be9a7c2886c9daa58f23699dd33bbe8ad9f Mon Sep 17 00:00:00 2001
From: Qingyu Qu <2283984853@qq.com>
Date: Tue, 3 Dec 2024 00:46:35 +0800
Subject: [PATCH] refactor: Move dual nonlinear solving to NonlinearSolveBase

Move the ForwardDiff dual-number problem aliases, the `__solve`/`__init`
overloads, `NonlinearSolveForwardDiffCache` and its `reinit!`/`solve!`
methods from `src/forward_diff.jl` into the NonlinearSolveBase ForwardDiff
extension, and delete the old file.

The extension imports every name explicitly, so `init` and `solve!` are
imported from CommonSolve alongside `solve`: the moved `__init` and
`solve!` methods call both of them unqualified.

Add ForwardDiff as a test dependency of NonlinearSolveFirstOrder together
with a regression test for Dual-of-BigFloat inputs (issue #512).
---
 .../ext/NonlinearSolveBaseForwardDiffExt.jl   | 115 +++++++++++++++++-
 lib/NonlinearSolveFirstOrder/Project.toml     |   3 +-
 .../test/misc_tests.jl                        |  10 ++
 src/NonlinearSolve.jl                         |   2 -
 src/forward_diff.jl                           |  99 ----------------
 5 files changed, 125 insertions(+), 104 deletions(-)
 delete mode 100644 src/forward_diff.jl

diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
index bb3165396..6357549ec 100644
--- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
+++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
@@ -2,17 +2,37 @@ module NonlinearSolveBaseForwardDiffExt
 
 using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
 using ArrayInterface: ArrayInterface
-using CommonSolve: solve
+using CommonSolve: CommonSolve, solve, solve!, init
+using ConcreteStructs: @concrete
 using DifferentiationInterface: DifferentiationInterface
 using FastClosures: @closure
 using ForwardDiff: ForwardDiff, Dual
 using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
                  NonlinearProblem, NonlinearLeastSquaresProblem, remake
 
-using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
+using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
+                          AbstractNonlinearSolveAlgorithm, Utils, InternalAPI,
+                          AbstractNonlinearSolveCache
 
 const DI = DifferentiationInterface
 
+# Algorithm types for which the dual-problem `__solve`/`__init` methods are generated.
+const ALL_SOLVER_TYPES = [
+    Nothing, AbstractNonlinearSolveAlgorithm
+]
+
+const DualNonlinearProblem = NonlinearProblem{
+    <:Union{Number, <:AbstractArray}, iip,
+    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
+} where {iip, T, V, P}
+const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
+    <:Union{Number, <:AbstractArray}, iip,
+    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
+} where {iip, T, V, P}
+const DualAbstractNonlinearProblem = Union{
+    DualNonlinearProblem, DualNonlinearLeastSquaresProblem
+}
+
 function NonlinearSolveBase.additional_incompatible_backend_check(
         prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
     return !ForwardDiff.can_dual(eltype(prob.u0))
@@ -102,4 +122,95 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
     return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
 end
 
+for algType in ALL_SOLVER_TYPES
+    @eval function SciMLBase.__solve(
+        prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
+    )
+        sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
+            prob, alg, args...; kwargs...
+        )
+        dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
+        return SciMLBase.build_solution(
+            prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
+        )
+    end
+end
+
+# Wraps the cache of the de-dualized problem together with the original dual parameters.
+@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
+    cache
+    prob
+    alg
+    p
+    values_p
+    partials_p
+end
+
+function InternalAPI.reinit!(
+    cache::NonlinearSolveForwardDiffCache, args...;
+    p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
+)
+    InternalAPI.reinit!(
+        cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs...
+    )
+    cache.p = p
+    cache.values_p = nodual_value(p)
+    cache.partials_p = ForwardDiff.partials(p)
+    return cache
+end
+
+for algType in ALL_SOLVER_TYPES
+    @eval function SciMLBase.__init(
+        prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
+    )
+        p = nodual_value(prob.p)
+        newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
+        cache = init(newprob, alg, args...; kwargs...)
+        return NonlinearSolveForwardDiffCache(
+            cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
+        )
+    end
+end
+
+function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
+    sol = solve!(cache.cache)
+    prob = cache.prob
+    uu = sol.u
+
+    fn = prob isa NonlinearLeastSquaresProblem ?
+         NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f
+
+    Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
+    Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)
+
+    # Sensitivity of the solution to the parameters: solve Jᵤ * z = -Jₚ.
+    z_arr = -Jᵤ \ Jₚ
+
+    sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
+    if cache.p isa Number
+        partials = sumfun((z_arr, cache.p))
+    else
+        partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
+    end
+
+    dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p)
+    return SciMLBase.build_solution(
+        prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
+    )
+end
+
+# Return the primal value(s) of `x`, dropping ForwardDiff partials; others pass through.
+nodual_value(x) = x
+nodual_value(x::Dual) = ForwardDiff.value(x)
+nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
+
+"""
+    pickchunksize(x) = pickchunksize(length(x))
+    pickchunksize(x::Int)
+
+Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.
+"""
+@inline pickchunksize(x) = pickchunksize(length(x))
+@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)
+
 end
diff --git a/lib/NonlinearSolveFirstOrder/Project.toml b/lib/NonlinearSolveFirstOrder/Project.toml
index ee2d2c9de..c299b6dc1 100644
--- a/lib/NonlinearSolveFirstOrder/Project.toml
+++ b/lib/NonlinearSolveFirstOrder/Project.toml
@@ -67,6 +67,7 @@ julia = "1.10"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
 Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
@@ -86,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Aqua", "BandedMatrices", "BenchmarkTools", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
+test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
diff --git a/lib/NonlinearSolveFirstOrder/test/misc_tests.jl b/lib/NonlinearSolveFirstOrder/test/misc_tests.jl
index 40fcb2c55..79c63f37c 100644
--- a/lib/NonlinearSolveFirstOrder/test/misc_tests.jl
+++ b/lib/NonlinearSolveFirstOrder/test/misc_tests.jl
@@ -20,3 +20,13 @@
     @test sol.retcode == ReturnCode.Success
     @test jac_calls == 0
 end
+
+@testitem "Dual of BigFloat: Issue #512" tags=[:core] begin
+    using NonlinearSolveFirstOrder, ForwardDiff
+    fn_iip = NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p)
+    u2 = [ForwardDiff.Dual(BigFloat(1.0), 5.0), ForwardDiff.Dual(BigFloat(1.0), 5.0),
+        ForwardDiff.Dual(BigFloat(1.0), 5.0)]
+    prob_iip_bf = NonlinearProblem{true}(fn_iip, u2, ForwardDiff.Dual(BigFloat(2.0), 5.0))
+    sol = solve(prob_iip_bf, NewtonRaphson())
+    @test sol.retcode == ReturnCode.Success
+end
diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl
index 4c44cc972..c6fcc1f12 100644
--- a/src/NonlinearSolve.jl
+++ b/src/NonlinearSolve.jl
@@ -62,8 +62,6 @@ const ALL_SOLVER_TYPES = [
     NonlinearSolvePolyAlgorithm
 ]
 
-include("forward_diff.jl")
-
 @setup_workload begin
     nonlinear_functions = (
         (NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1),
diff --git a/src/forward_diff.jl b/src/forward_diff.jl
deleted file mode 100644
index 5bb98561c..000000000
--- a/src/forward_diff.jl
+++ /dev/null
@@ -1,99 +0,0 @@
-const DualNonlinearProblem = NonlinearProblem{
-    <:Union{Number, <:AbstractArray}, iip,
-    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
-} where {iip, T, V, P}
-const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
-    <:Union{Number, <:AbstractArray}, iip,
-    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
-} where {iip, T, V, P}
-const DualAbstractNonlinearProblem = Union{
-    DualNonlinearProblem, DualNonlinearLeastSquaresProblem
-}
-
-for algType in ALL_SOLVER_TYPES
-    @eval function SciMLBase.__solve(
-        prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
-    )
-        sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
-            prob, alg, args...; kwargs...
-        )
-        dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
-        return SciMLBase.build_solution(
-            prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
-        )
-    end
-end
-
-@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
-    cache
-    prob
-    alg
-    p
-    values_p
-    partials_p
-end
-
-function InternalAPI.reinit!(
-    cache::NonlinearSolveForwardDiffCache, args...;
-    p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
-)
-    InternalAPI.reinit!(
-        cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs...
-    )
-    cache.p = p
-    cache.values_p = nodual_value(p)
-    cache.partials_p = ForwardDiff.partials(p)
-    return cache
-end
-
-for algType in ALL_SOLVER_TYPES
-    @eval function SciMLBase.__init(
-        prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
-    )
-        p = nodual_value(prob.p)
-        newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
-        cache = init(newprob, alg, args...; kwargs...)
-        return NonlinearSolveForwardDiffCache(
-            cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
-        )
-    end
-end
-
-function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
-    sol = solve!(cache.cache)
-    prob = cache.prob
-    uu = sol.u
-
-    fn = prob isa NonlinearLeastSquaresProblem ?
-         NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f
-
-    Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
-    Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)
-
-    z_arr = -Jᵤ \ Jₚ
-
-    sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
-    if cache.p isa Number
-        partials = sumfun((z_arr, cache.p))
-    else
-        partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
-    end
-
-    dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p)
-    return SciMLBase.build_solution(
-        prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
-    )
-end
-
-nodual_value(x) = x
-nodual_value(x::Dual) = ForwardDiff.value(x)
-nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
-
-"""
-    pickchunksize(x) = pickchunksize(length(x))
-    pickchunksize(x::Int)
-
-Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.
-"""
-@inline pickchunksize(x) = pickchunksize(length(x))
-@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)