From 1075e2d8ea360d69814e0948f96629dc175fc17e Mon Sep 17 00:00:00 2001
From: Avik Pal <avikpal@mit.edu>
Date: Wed, 25 Sep 2024 18:04:59 -0400
Subject: [PATCH] feat: automatic backend selection for autodiff

---
 lib/NonlinearSolveBase/Project.toml           |   6 +
 .../ext/NonlinearSolveBaseForwardDiffExt.jl   |   9 +-
 .../src/NonlinearSolveBase.jl                 |   9 ++
 lib/NonlinearSolveBase/src/autodiff.jl        | 109 ++++++++++++++++++
 4 files changed, 132 insertions(+), 1 deletion(-)
 create mode 100644 lib/NonlinearSolveBase/src/autodiff.jl

diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml
index 94d12d822..819bcc79f 100644
--- a/lib/NonlinearSolveBase/Project.toml
+++ b/lib/NonlinearSolveBase/Project.toml
@@ -4,10 +4,13 @@ authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
 version = "1.0.0"
 
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -24,10 +27,13 @@ NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
 NonlinearSolveBaseSparseArraysExt = "SparseArrays"
 
 [compat]
+ADTypes = "1.9"
 ArrayInterface = "7.9"
 CommonSolve = "0.2.4"
 Compat = "4.15"
 ConcreteStructs = "0.2.3"
+DifferentiationInterface = "0.6.1"
+EnzymeCore = "0.8"
 FastClosures = "0.3"
 ForwardDiff = "0.10.36"
 LinearAlgebra = "1.10"
diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
index 469c5944b..31550da96 100644
--- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
+++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
@@ -1,13 +1,20 @@
 module NonlinearSolveBaseForwardDiffExt
 
+using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
 using CommonSolve: solve
 using FastClosures: @closure
 using ForwardDiff: ForwardDiff, Dual
-using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearProblem,
+using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
+                 NonlinearProblem,
                  NonlinearLeastSquaresProblem, remake
 
 using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
 
+function NonlinearSolveBase.additional_incompatible_backend_check(
+        prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
+    return !ForwardDiff.can_dual(eltype(prob.u0))
+end
+
 Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
 Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
 Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)
diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
index 63f4b697c..4b3eec258 100644
--- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
+++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
@@ -1,8 +1,11 @@
 module NonlinearSolveBase
 
+using ADTypes: ADTypes, AbstractADType, ForwardMode, ReverseMode
 using ArrayInterface: ArrayInterface
 using Compat: @compat
 using ConcreteStructs: @concrete
+using DifferentiationInterface: DifferentiationInterface
+using EnzymeCore: EnzymeCore
 using FastClosures: @closure
 using LinearAlgebra: norm
 using Markdown: @doc_str
@@ -13,6 +16,8 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear
                  isinplace, warn_paramtype
 using StaticArraysCore: StaticArray
 
+const DI = DifferentiationInterface
+
 include("public.jl")
 include("utils.jl")
 
@@ -20,9 +25,13 @@ include("immutable_problem.jl")
 include("common_defaults.jl")
 include("termination_conditions.jl")
 
+include("autodiff.jl")
+
 # Unexported Public API
 @compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
 @compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
+@compat(public, (select_forward_mode_autodiff, select_reverse_mode_autodiff,
+    select_jacobian_autodiff))
 
 export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode,
        AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
diff --git a/lib/NonlinearSolveBase/src/autodiff.jl b/lib/NonlinearSolveBase/src/autodiff.jl
new file mode 100644
index 000000000..d2e51389d
--- /dev/null
+++ b/lib/NonlinearSolveBase/src/autodiff.jl
@@ -0,0 +1,109 @@
+# Here we determine the preferred AD backend. We have a predefined list of ADs and then
+# we select the first one that is avialable and would work with the problem.
+
+# Ordering is important here. We want to select the first one that is compatible with the
+# problem.
+const ReverseADs = [
+    ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
+    ADTypes.AutoZygote(),
+    ADTypes.AutoTracker(),
+    ADTypes.AutoReverseDiff(),
+    ADTypes.AutoFiniteDiff()
+]
+
+const ForwardADs = [
+    ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward),
+    ADTypes.AutoPolyesterForwardDiff(),
+    ADTypes.AutoForwardDiff(),
+    ADTypes.AutoFiniteDiff()
+]
+
+# TODO: Handle Sparsity
+
+function select_forward_mode_autodiff(
+        prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true)
+    if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ForwardMode)
+        @warn "The chosen AD backend $(ad) is not a forward mode AD. Use with caution."
+    end
+    if incompatible_backend_and_problem(prob, ad)
+        adₙ = select_forward_mode_autodiff(prob, nothing; warn_check_mode)
+        @warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
+               running autodiff selection detected `$(adₙ)` as a potential forward mode \
+               backend."
+        return adₙ
+    end
+    return ad
+end
+
+function select_forward_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing;
+        warn_check_mode::Bool = true)
+    idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs)
+    idx !== nothing && return ForwardADs[idx]
+    throw(ArgumentError("No forward mode AD backend is compatible with the chosen problem. \
+                         This could be because no forward mode autodiff backend is loaded \
+                         or the loaded backends don't support the problem."))
+end
+
+function select_reverse_mode_autodiff(
+        prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true)
+    if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ReverseMode)
+        if !is_finite_differences_backend(ad)
+            @warn "The chosen AD backend $(ad) is not a reverse mode AD. Use with caution."
+        else
+            @warn "The chosen AD backend $(ad) is a finite differences backend. This might \
+                   be slow and inaccurate. Use with caution."
+        end
+    end
+    if incompatible_backend_and_problem(prob, ad)
+        adₙ = select_reverse_mode_autodiff(prob, nothing; warn_check_mode)
+        @warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
+               running autodiff selection detected `$(adₙ)` as a potential reverse mode \
+               backend."
+        return adₙ
+    end
+    return ad
+end
+
+function select_reverse_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing;
+        warn_check_mode::Bool = true)
+    idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs)
+    idx !== nothing && return ReverseADs[idx]
+    throw(ArgumentError("No reverse mode AD backend is compatible with the chosen problem. \
+                         This could be because no reverse mode autodiff backend is loaded \
+                         or the loaded backends don't support the problem."))
+end
+
+function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ad::AbstractADType)
+    if incompatible_backend_and_problem(prob, ad)
+        adₙ = select_jacobian_autodiff(prob, nothing)
+        @warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
+               running autodiff selection detected `$(adₙ)` as a potential jacobian \
+               backend."
+        return adₙ
+    end
+    return ad
+end
+
+function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ::Nothing)
+    idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs)
+    idx !== nothing && !is_finite_differences_backend(ForwardADs[idx]) &&
+        return ForwardADs[idx]
+    idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs)
+    idx !== nothing && return ReverseADs[idx]
+    throw(ArgumentError("No jacobian AD backend is compatible with the chosen problem. \
+                         This could be because no jacobian autodiff backend is loaded \
+                         or the loaded backends don't support the problem."))
+end
+
+function incompatible_backend_and_problem(
+        prob::AbstractNonlinearProblem, ad::AbstractADType)
+    !DI.check_available(ad) && return true
+    SciMLBase.isinplace(prob) && !DI.check_inplace(ad) && return true
+    return additional_incompatible_backend_check(prob, ad)
+end
+
+additional_incompatible_backend_check(::AbstractNonlinearProblem, ::AbstractADType) = false
+
+is_finite_differences_backend(ad::AbstractADType) = false
+is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true
+is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true