From b49e42d8d8ebfc25ccabb7d444406b73346f014e Mon Sep 17 00:00:00 2001
From: Avik Pal <avikpal@mit.edu>
Date: Mon, 21 Oct 2024 20:34:34 -0400
Subject: [PATCH] fix: dispatch forwarddiff on `__init` and `__solve`

---
 src/NonlinearSolve.jl        |  3 ++-
 src/internal/forward_diff.jl | 51 ++++++++++++++++++++----------------
 2 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl
index 5f21936a1..6a0d50054 100644
--- a/src/NonlinearSolve.jl
+++ b/src/NonlinearSolve.jl
@@ -67,7 +67,6 @@ include("descent/damped_newton.jl")
 include("descent/geodesic_acceleration.jl")
 
 include("internal/jacobian.jl")
-include("internal/forward_diff.jl")
 include("internal/linear_solve.jl")
 include("internal/termination.jl")
 include("internal/tracing.jl")
@@ -82,6 +81,8 @@ include("core/generalized_first_order.jl")
 include("core/spectral_methods.jl")
 include("core/noinit.jl")
 
+include("internal/forward_diff.jl") # we need to define after the algorithms
+
 include("algorithms/raphson.jl")
 include("algorithms/pseudo_transient.jl")
 include("algorithms/broyden.jl")
diff --git a/src/internal/forward_diff.jl b/src/internal/forward_diff.jl
index a4238674e..86c223fc8 100644
--- a/src/internal/forward_diff.jl
+++ b/src/internal/forward_diff.jl
@@ -1,14 +1,19 @@
-# XXX: dispatch on `__solve` & `__init`
-function SciMLBase.solve(
-        prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
-            <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
-        alg::Union{Nothing, AbstractNonlinearAlgorithm},
-        args...;
-        kwargs...) where {T, V, P, iip}
-    sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
-    dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
-    return SciMLBase.build_solution(
-        prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
+const DualNonlinearProblem = NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
+    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
+const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
+    <:Union{Number, <:AbstractArray}, iip,
+    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
+const DualAbstractNonlinearProblem = Union{
+    DualNonlinearProblem, DualNonlinearLeastSquaresProblem}
+
+for algType in (Nothing, AbstractNonlinearSolveAlgorithm)
+    @eval function SciMLBase.__solve(
+            prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
+        sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
+        dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
+        return SciMLBase.build_solution(
+            prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
+    end
 end
 
 @concrete mutable struct NonlinearSolveForwardDiffCache
@@ -32,17 +37,19 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
     return cache
 end
 
-function SciMLBase.init(
-        prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
-            <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
-        alg::Union{Nothing, AbstractNonlinearAlgorithm},
-        args...;
-        kwargs...) where {T, V, P, iip}
-    p = __value(prob.p)
-    newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
-    cache = init(newprob, alg, args...; kwargs...)
-    return NonlinearSolveForwardDiffCache(
-        cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
+for algType in (
+    Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
+    SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
+    GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm
+)
+    @eval function SciMLBase.__init(
+            prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
+        p = __value(prob.p)
+        newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
+        cache = init(newprob, alg, args...; kwargs...)
+        return NonlinearSolveForwardDiffCache(
+            cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
+    end
 end
 
 function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)