diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 94d12d822..819bcc79f 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -4,10 +4,13 @@ authors = ["Avik Pal and contributors"] version = "1.0.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -24,10 +27,13 @@ NonlinearSolveBaseForwardDiffExt = "ForwardDiff" NonlinearSolveBaseSparseArraysExt = "SparseArrays" [compat] +ADTypes = "1.9" ArrayInterface = "7.9" CommonSolve = "0.2.4" Compat = "4.15" ConcreteStructs = "0.2.3" +DifferentiationInterface = "0.6.1" +EnzymeCore = "0.8" FastClosures = "0.3" ForwardDiff = "0.10.36" LinearAlgebra = "1.10" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 469c5944b..31550da96 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -1,13 +1,20 @@ module NonlinearSolveBaseForwardDiffExt +using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff using CommonSolve: solve using FastClosures: @closure using ForwardDiff: ForwardDiff, Dual -using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearProblem, +using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + NonlinearProblem, NonlinearLeastSquaresProblem, remake using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils +function NonlinearSolveBase.additional_incompatible_backend_check( + prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff}) + return !ForwardDiff.can_dual(eltype(prob.u0)) +end + Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x)) Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x) diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 63f4b697c..4b3eec258 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -1,8 +1,11 @@ module NonlinearSolveBase +using ADTypes: ADTypes, AbstractADType, ForwardMode, ReverseMode using ArrayInterface: ArrayInterface using Compat: @compat using ConcreteStructs: @concrete +using DifferentiationInterface: DifferentiationInterface +using EnzymeCore: EnzymeCore using FastClosures: @closure using LinearAlgebra: norm using Markdown: @doc_str @@ -13,6 +16,8 @@ using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinear isinplace, warn_paramtype using StaticArraysCore: StaticArray +const DI = DifferentiationInterface + include("public.jl") include("utils.jl") @@ -20,9 +25,13 @@ include("immutable_problem.jl") include("common_defaults.jl") include("termination_conditions.jl") +include("autodiff.jl") + # Unexported Public API @compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance)) @compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution)) +@compat(public, (select_forward_mode_autodiff, select_reverse_mode_autodiff, + select_jacobian_autodiff)) export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode, AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode, diff --git a/lib/NonlinearSolveBase/src/autodiff.jl b/lib/NonlinearSolveBase/src/autodiff.jl new file mode 100644 index 000000000..d2e51389d --- /dev/null +++ b/lib/NonlinearSolveBase/src/autodiff.jl @@ -0,0 +1,109 @@ +# Here we determine the preferred AD backend. We have a predefined list of ADs and then +# we select the first one that is avialable and would work with the problem. + +# Ordering is important here. We want to select the first one that is compatible with the +# problem. +const ReverseADs = [ + ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse), + ADTypes.AutoZygote(), + ADTypes.AutoTracker(), + ADTypes.AutoReverseDiff(), + ADTypes.AutoFiniteDiff() +] + +const ForwardADs = [ + ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward), + ADTypes.AutoPolyesterForwardDiff(), + ADTypes.AutoForwardDiff(), + ADTypes.AutoFiniteDiff() +] + +# TODO: Handle Sparsity + +function select_forward_mode_autodiff( + prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true) + if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ForwardMode) + @warn "The chosen AD backend $(ad) is not a forward mode AD. Use with caution." + end + if incompatible_backend_and_problem(prob, ad) + adₙ = select_forward_mode_autodiff(prob, nothing; warn_check_mode) + @warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \ + running autodiff selection detected `$(adₙ)` as a potential forward mode \ + backend." + return adₙ + end + return ad +end + +function select_forward_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing; + warn_check_mode::Bool = true) + idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs) + idx !== nothing && return ForwardADs[idx] + throw(ArgumentError("No forward mode AD backend is compatible with the chosen problem. \ + This could be because no forward mode autodiff backend is loaded \ + or the loaded backends don't support the problem.")) +end + +function select_reverse_mode_autodiff( + prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true) + if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ReverseMode) + if !is_finite_differences_backend(ad) + @warn "The chosen AD backend $(ad) is not a reverse mode AD. Use with caution." + else + @warn "The chosen AD backend $(ad) is a finite differences backend. This might \ + be slow and inaccurate. Use with caution." + end + end + if incompatible_backend_and_problem(prob, ad) + adₙ = select_reverse_mode_autodiff(prob, nothing; warn_check_mode) + @warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \ + running autodiff selection detected `$(adₙ)` as a potential reverse mode \ + backend." + return adₙ + end + return ad +end + +function select_reverse_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing; + warn_check_mode::Bool = true) + idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs) + idx !== nothing && return ReverseADs[idx] + throw(ArgumentError("No reverse mode AD backend is compatible with the chosen problem. \ + This could be because no reverse mode autodiff backend is loaded \ + or the loaded backends don't support the problem.")) +end + +function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ad::AbstractADType) + if incompatible_backend_and_problem(prob, ad) + adₙ = select_jacobian_autodiff(prob, nothing) + @warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \ + running autodiff selection detected `$(adₙ)` as a potential jacobian \ + backend." + return adₙ + end + return ad +end + +function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ::Nothing) + idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs) + idx !== nothing && !is_finite_differences_backend(ForwardADs[idx]) && + return ForwardADs[idx] + idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs) + idx !== nothing && return ReverseADs[idx] + throw(ArgumentError("No jacobian AD backend is compatible with the chosen problem. \ + This could be because no jacobian autodiff backend is loaded \ + or the loaded backends don't support the problem.")) +end + +function incompatible_backend_and_problem( + prob::AbstractNonlinearProblem, ad::AbstractADType) + !DI.check_available(ad) && return true + SciMLBase.isinplace(prob) && !DI.check_inplace(ad) && return true + return additional_incompatible_backend_check(prob, ad) +end + +additional_incompatible_backend_check(::AbstractNonlinearProblem, ::AbstractADType) = false + +is_finite_differences_backend(ad::AbstractADType) = false +is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true +is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true