diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 5f21936a1..6a0d50054 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -67,7 +67,6 @@ include("descent/damped_newton.jl") include("descent/geodesic_acceleration.jl") include("internal/jacobian.jl") -include("internal/forward_diff.jl") include("internal/linear_solve.jl") include("internal/termination.jl") include("internal/tracing.jl") @@ -82,6 +81,8 @@ include("core/generalized_first_order.jl") include("core/spectral_methods.jl") include("core/noinit.jl") +include("internal/forward_diff.jl") # we need to define after the algorithms + include("algorithms/raphson.jl") include("algorithms/pseudo_transient.jl") include("algorithms/broyden.jl") diff --git a/src/internal/forward_diff.jl b/src/internal/forward_diff.jl index a4238674e..86c223fc8 100644 --- a/src/internal/forward_diff.jl +++ b/src/internal/forward_diff.jl @@ -1,14 +1,19 @@ -# XXX: dispatch on `__solve` & `__init` -function SciMLBase.solve( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::Union{Nothing, AbstractNonlinearAlgorithm}, - args...; - kwargs...) where {T, V, P, iip} - sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...) - dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p) - return SciMLBase.build_solution( - prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) +const DualNonlinearProblem = NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem} + +for algType in (Nothing, AbstractNonlinearSolveAlgorithm) + @eval function SciMLBase.__solve( + prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...) + sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...) + dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) + end end @concrete mutable struct NonlinearSolveForwardDiffCache @@ -32,17 +37,19 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache; return cache end -function SciMLBase.init( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::Union{Nothing, AbstractNonlinearAlgorithm}, - args...; - kwargs...) where {T, V, P, iip} - p = __value(prob.p) - newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...) - cache = init(newprob, alg, args...; kwargs...) - return NonlinearSolveForwardDiffCache( - cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)) +for algType in ( + Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane, + SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, + GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm +) + @eval function SciMLBase.__init( + prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...) + p = __value(prob.p) + newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)) + end end function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)