From 6b0002b7b8cda2524a6211663d03fa64df42fa17 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 1 Nov 2024 14:51:53 -0400 Subject: [PATCH] fix: hessian (#489) * fix: hessian through nonlinear solvers * feat: extend gradient support for cached nlls --- Project.toml | 2 +- lib/NonlinearSolveBase/Project.toml | 2 +- .../ext/NonlinearSolveBaseForwardDiffExt.jl | 100 ++---------------- .../src/NonlinearSolveBase.jl | 4 +- lib/NonlinearSolveBase/src/autodiff.jl | 62 +++++++++++ lib/NonlinearSolveBase/src/public.jl | 1 + src/forward_diff.jl | 12 ++- test/forward_ad_tests.jl | 94 ++++++++++++++++ 8 files changed, 177 insertions(+), 100 deletions(-) diff --git a/Project.toml b/Project.toml index 845611499..93643123b 100644 --- a/Project.toml +++ b/Project.toml @@ -89,7 +89,7 @@ NLSolvers = "0.5" NLsolve = "4.5" NaNMath = "1" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1" +NonlinearSolveBase = "1.2" NonlinearSolveFirstOrder = "1" NonlinearSolveQuasiNewton = "1" NonlinearSolveSpectralMethods = "1" diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 4abb81cde..e7ed5851f 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" authors = ["Avik Pal and contributors"] -version = "1.1.0" +version = "1.2.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 0b16391c4..bb3165396 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -6,7 +6,6 @@ using CommonSolve: solve using DifferentiationInterface: DifferentiationInterface using FastClosures: @closure using ForwardDiff: ForwardDiff, Dual -using LinearAlgebra: mul! using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, remake @@ -20,11 +19,14 @@ function NonlinearSolveBase.additional_incompatible_backend_check( end Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V -Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x)) +Utils.value(x::Dual) = ForwardDiff.value(x) Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x) function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( - prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem}, + prob::Union{ + IntervalNonlinearProblem, NonlinearProblem, + ImmutableNonlinearProblem, NonlinearLeastSquaresProblem + }, alg, args...; kwargs... ) p = Utils.value(prob.p) @@ -35,98 +37,14 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( newprob = remake(prob; p, u0 = Utils.value(prob.u0)) end - sol = solve(newprob, alg, args...; kwargs...) - - uu = sol.u - Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, p) - Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, p) - z = -Jᵤ \ Jₚ - pp = prob.p - sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z) - - if uu isa Number - partials = sum(sumfun, zip(z, pp)) - elseif p isa Number - partials = sumfun((z, pp)) - else - partials = sum(sumfun, zip(eachcol(z), pp)) - end - - return sol, partials -end - -function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( - prob::NonlinearLeastSquaresProblem, alg, args...; kwargs... -) - p = Utils.value(prob.p) - newprob = remake(prob; p, u0 = Utils.value(prob.u0)) sol = solve(newprob, alg, args...; kwargs...) uu = sol.u - # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use - # nested autodiff as the last resort - if SciMLBase.has_vjp(prob.f) - if SciMLBase.isinplace(prob) - vjp_fn = @closure (du, u, p) -> begin - resid = Utils.safe_similar(du, length(sol.resid)) - prob.f(resid, u, p) - prob.f.vjp(du, resid, u, p) - du .*= 2 - return nothing - end - else - vjp_fn = @closure (u, p) -> begin - resid = prob.f(u, p) - return reshape(2 .* prob.f.vjp(resid, u, p), size(u)) - end - end - elseif SciMLBase.has_jac(prob.f) - if SciMLBase.isinplace(prob) - vjp_fn = @closure (du, u, p) -> begin - J = Utils.safe_similar(du, length(sol.resid), length(u)) - prob.f.jac(J, u, p) - resid = Utils.safe_similar(du, length(sol.resid)) - prob.f(resid, u, p) - mul!(reshape(du, 1, :), vec(resid)', J, 2, false) - return nothing - end - else - vjp_fn = @closure (u, p) -> begin - return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u)) - end - end - else - # For small problems, nesting ForwardDiff is actually quite fast - autodiff = length(uu) + length(sol.resid) ≥ 50 ? - NonlinearSolveBase.select_reverse_mode_autodiff(prob, nothing) : - AutoForwardDiff() - - if SciMLBase.isinplace(prob) - vjp_fn = @closure (du, u, p) -> begin - resid = Utils.safe_similar(du, length(sol.resid)) - prob.f(resid, u, p) - # Using `Constant` lead to dual ordering issues - ff = @closure (du, u) -> prob.f(du, u, p) - resid2 = copy(resid) - DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,)) - @. du *= 2 - return nothing - end - else - vjp_fn = @closure (u, p) -> begin - v = prob.f(u, p) - # Using `Constant` lead to dual ordering issues - ff = Base.Fix2(prob.f, p) - res = only(DI.pullback(ff, autodiff, u, (v,))) - ArrayInterface.can_setindex(res) || return 2 .* res - @. res *= 2 - return res - end - end - end + fn = prob isa NonlinearLeastSquaresProblem ? + NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f - Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p) - Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p) + Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, p) + Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, p) z = -Jᵤ \ Jₚ pp = prob.p sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index a08384677..412f6a748 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -5,7 +5,7 @@ using ConcreteStructs: @concrete using FastClosures: @closure using Preferences: @load_preference, @set_preferences! -using ADTypes: ADTypes, AbstractADType, AutoSparse, NoSparsityDetector, +using ADTypes: ADTypes, AbstractADType, AutoSparse, AutoForwardDiff, NoSparsityDetector, KnownJacobianSparsityDetector using Adapt: WrappedArray using ArrayInterface: ArrayInterface @@ -25,7 +25,7 @@ using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator using SymbolicIndexingInterface: SymbolicIndexingInterface -using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind +using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind, mul! using Markdown: @doc_str using Printf: @printf diff --git a/lib/NonlinearSolveBase/src/autodiff.jl b/lib/NonlinearSolveBase/src/autodiff.jl index 395580924..f70e9770b 100644 --- a/lib/NonlinearSolveBase/src/autodiff.jl +++ b/lib/NonlinearSolveBase/src/autodiff.jl @@ -128,3 +128,65 @@ end is_finite_differences_backend(ad::AbstractADType) = false is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true + +function nlls_generate_vjp_function(prob::NonlinearLeastSquaresProblem, sol, uu) + # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use + # nested autodiff as the last resort + if SciMLBase.has_vjp(prob.f) + if SciMLBase.isinplace(prob) + return @closure (du, u, p) -> begin + resid = Utils.safe_similar(du, length(sol.resid)) + prob.f.vjp(resid, u, p) + prob.f.vjp(du, resid, u, p) + du .*= 2 + return nothing + end + else + return @closure (u, p) -> begin + resid = prob.f(u, p) + return reshape(2 .* prob.f.vjp(resid, u, p), size(u)) + end + end + elseif SciMLBase.has_jac(prob.f) + if SciMLBase.isinplace(prob) + return @closure (du, u, p) -> begin + J = Utils.safe_similar(du, length(sol.resid), length(u)) + prob.f.jac(J, u, p) + resid = Utils.safe_similar(du, length(sol.resid)) + prob.f(resid, u, p) + mul!(reshape(du, 1, :), vec(resid)', J, 2, false) + return nothing + end + else + return @closure (u, p) -> begin + return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u)) + end + end + else + # For small problems, nesting ForwardDiff is actually quite fast + autodiff = length(uu) + length(sol.resid) ≥ 50 ? + select_reverse_mode_autodiff(prob, nothing) : AutoForwardDiff() + + if SciMLBase.isinplace(prob) + return @closure (du, u, p) -> begin + resid = Utils.safe_similar(du, length(sol.resid)) + prob.f(resid, u, p) + # Using `Constant` lead to dual ordering issues + ff = @closure (du, u) -> prob.f(du, u, p) + resid2 = copy(resid) + DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,)) + @. du *= 2 + return nothing + end + else + return @closure (u, p) -> begin + v = prob.f(u, p) + # Using `Constant` lead to dual ordering issues + res = only(DI.pullback(Base.Fix2(prob.f, p), autodiff, u, (v,))) + ArrayInterface.can_setindex(res) || return 2 .* res + @. res *= 2 + return res + end + end + end +end diff --git a/lib/NonlinearSolveBase/src/public.jl b/lib/NonlinearSolveBase/src/public.jl index 2101f4274..d076f7873 100644 --- a/lib/NonlinearSolveBase/src/public.jl +++ b/lib/NonlinearSolveBase/src/public.jl @@ -10,6 +10,7 @@ function nonlinearsolve_forwarddiff_solve end function nonlinearsolve_dual_solution end function nonlinearsolve_∂f_∂p end function nonlinearsolve_∂f_∂u end +function nlls_generate_vjp_function end # Nonlinear Solve Termination Conditions abstract type AbstractNonlinearTerminationMode end diff --git a/src/forward_diff.jl b/src/forward_diff.jl index 410e818c3..d34bca877 100644 --- a/src/forward_diff.jl +++ b/src/forward_diff.jl @@ -48,9 +48,8 @@ function InternalAPI.reinit!( end for algType in ALL_SOLVER_TYPES - # XXX: Extend to DualNonlinearLeastSquaresProblem @eval function SciMLBase.__init( - prob::DualNonlinearProblem, alg::$(algType), args...; kwargs... + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... ) p = nodual_value(prob.p) newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) @@ -64,10 +63,13 @@ end function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache) sol = solve!(cache.cache) prob = cache.prob - uu = sol.u - Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p) - Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p) + + fn = prob isa NonlinearLeastSquaresProblem ? + NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f + + Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p) + Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p) z_arr = -Jᵤ \ Jₚ diff --git a/test/forward_ad_tests.jl b/test/forward_ad_tests.jl index dcd1c2e3a..942d66c9b 100644 --- a/test/forward_ad_tests.jl +++ b/test/forward_ad_tests.jl @@ -124,3 +124,97 @@ end end end end + +@testitem "NLLS Hessian SciML/NonlinearSolve.jl#445" tags=[:core] begin + using ForwardDiff, FiniteDiff + + function objfn(F, init, params) + th1, th2 = init + px, py, l1, l2 = params + F[1] = l1 * cos(th1) + l2 * cos(th1 + th2) - px + F[2] = l1 * sin(th1) + l2 * sin(th1 + th2) - py + return F + end + + function solve_nlprob(pxpy) + px, py = pxpy + theta1 = pi / 4 + theta2 = pi / 4 + initial_guess = [theta1; theta2] + l1 = 60 + l2 = 60 + p = [px; py; l1; l2] + prob = NonlinearLeastSquaresProblem( + NonlinearFunction(objfn, resid_prototype = zeros(2)), + initial_guess, p + ) + resu = solve( + prob, + reltol = 1e-12, abstol = 1e-12 + ) + th1, th2 = resu.u + cable1_base = [-90; 0; 0] + cable2_base = [-150; 0; 0] + cable3_base = [150; 0; 0] + cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0] + cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2; + l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0] + c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 + + (cable1_top[2] - cable1_base[2])^2) + c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 + + (cable23_top[2] - cable2_base[2])^2) + c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 + + (cable23_top[2] - cable3_base[2])^2) + return c1_length + c2_length + c3_length + end + + grad1 = ForwardDiff.gradient(solve_nlprob, [34.0, 87.0]) + grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob, [34.0, 87.0]) + + @test grad1≈grad2 atol=1e-3 + + hess1 = ForwardDiff.hessian(solve_nlprob, [34.0, 87.0]) + hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob, [34.0, 87.0]) + + @test hess1≈hess2 atol=1e-3 + + function solve_nlprob_with_cache(pxpy) + px, py = pxpy + theta1 = pi / 4 + theta2 = pi / 4 + initial_guess = [theta1; theta2] + l1 = 60 + l2 = 60 + p = [px; py; l1; l2] + prob = NonlinearLeastSquaresProblem( + NonlinearFunction(objfn, resid_prototype = zeros(2)), + initial_guess, p + ) + cache = init(prob; reltol = 1e-12, abstol = 1e-12) + resu = solve!(cache) + th1, th2 = resu.u + cable1_base = [-90; 0; 0] + cable2_base = [-150; 0; 0] + cable3_base = [150; 0; 0] + cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0] + cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2; + l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0] + c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 + + (cable1_top[2] - cable1_base[2])^2) + c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 + + (cable23_top[2] - cable2_base[2])^2) + c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 + + (cable23_top[2] - cable3_base[2])^2) + return c1_length + c2_length + c3_length + end + + grad1 = ForwardDiff.gradient(solve_nlprob_with_cache, [34.0, 87.0]) + grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob_with_cache, [34.0, 87.0]) + + @test grad1≈grad2 atol=1e-3 + + hess1 = ForwardDiff.hessian(solve_nlprob_with_cache, [34.0, 87.0]) + hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob_with_cache, [34.0, 87.0]) + + @test hess1≈hess2 atol=1e-3 +end