-
-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: automatic backend selection for autodiff
- Loading branch information
Showing
4 changed files
with
132 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,13 @@ authors = ["Avik Pal <[email protected]> and contributors"] | |
version = "1.0.0" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | ||
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" | ||
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" | ||
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" | ||
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" | ||
|
@@ -24,10 +27,13 @@ NonlinearSolveBaseForwardDiffExt = "ForwardDiff" | |
NonlinearSolveBaseSparseArraysExt = "SparseArrays" | ||
|
||
[compat] | ||
ADTypes = "1.9" | ||
ArrayInterface = "7.9" | ||
CommonSolve = "0.2.4" | ||
Compat = "4.15" | ||
ConcreteStructs = "0.2.3" | ||
DifferentiationInterface = "0.6.1" | ||
EnzymeCore = "0.8" | ||
FastClosures = "0.3" | ||
ForwardDiff = "0.10.36" | ||
LinearAlgebra = "1.10" | ||
|
9 changes: 8 additions & 1 deletion
9
lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Here we determine the preferred AD backend. We have a predefined list of ADs and then | ||
# we select the first one that is avialable and would work with the problem. | ||
|
||
# Ordering is important here. We want to select the first one that is compatible with the | ||
# problem. | ||
const ReverseADs = [ | ||
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse), | ||
ADTypes.AutoZygote(), | ||
ADTypes.AutoTracker(), | ||
ADTypes.AutoReverseDiff(), | ||
ADTypes.AutoFiniteDiff() | ||
] | ||
|
||
const ForwardADs = [ | ||
ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward), | ||
ADTypes.AutoPolyesterForwardDiff(), | ||
ADTypes.AutoForwardDiff(), | ||
ADTypes.AutoFiniteDiff() | ||
] | ||
|
||
# TODO: Handle Sparsity | ||
|
||
function select_forward_mode_autodiff( | ||
prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true) | ||
if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ForwardMode) | ||
@warn "The chosen AD backend $(ad) is not a forward mode AD. Use with caution." | ||
end | ||
if incompatible_backend_and_problem(prob, ad) | ||
adₙ = select_forward_mode_autodiff(prob, nothing; warn_check_mode) | ||
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \ | ||
running autodiff selection detected `$(adₙ)` as a potential forward mode \ | ||
backend." | ||
return adₙ | ||
end | ||
return ad | ||
end | ||
|
||
function select_forward_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing; | ||
warn_check_mode::Bool = true) | ||
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs) | ||
idx !== nothing && return ForwardADs[idx] | ||
throw(ArgumentError("No forward mode AD backend is compatible with the chosen problem. \ | ||
This could be because no forward mode autodiff backend is loaded \ | ||
or the loaded backends don't support the problem.")) | ||
end | ||
|
||
function select_reverse_mode_autodiff( | ||
prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true) | ||
if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ReverseMode) | ||
if !is_finite_differences_backend(ad) | ||
@warn "The chosen AD backend $(ad) is not a reverse mode AD. Use with caution." | ||
else | ||
@warn "The chosen AD backend $(ad) is a finite differences backend. This might \ | ||
be slow and inaccurate. Use with caution." | ||
end | ||
end | ||
if incompatible_backend_and_problem(prob, ad) | ||
adₙ = select_reverse_mode_autodiff(prob, nothing; warn_check_mode) | ||
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \ | ||
running autodiff selection detected `$(adₙ)` as a potential reverse mode \ | ||
backend." | ||
return adₙ | ||
end | ||
return ad | ||
end | ||
|
||
function select_reverse_mode_autodiff(prob::AbstractNonlinearProblem, ::Nothing; | ||
warn_check_mode::Bool = true) | ||
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs) | ||
idx !== nothing && return ReverseADs[idx] | ||
throw(ArgumentError("No reverse mode AD backend is compatible with the chosen problem. \ | ||
This could be because no reverse mode autodiff backend is loaded \ | ||
or the loaded backends don't support the problem.")) | ||
end | ||
|
||
function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ad::AbstractADType) | ||
if incompatible_backend_and_problem(prob, ad) | ||
adₙ = select_jacobian_autodiff(prob, nothing) | ||
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \ | ||
running autodiff selection detected `$(adₙ)` as a potential jacobian \ | ||
backend." | ||
return adₙ | ||
end | ||
return ad | ||
end | ||
|
||
function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ::Nothing) | ||
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ForwardADs) | ||
idx !== nothing && !is_finite_differences_backend(ForwardADs[idx]) && | ||
return ForwardADs[idx] | ||
idx = findfirst(!Base.Fix1(incompatible_backend_and_problem, prob), ReverseADs) | ||
idx !== nothing && return ReverseADs[idx] | ||
throw(ArgumentError("No jacobian AD backend is compatible with the chosen problem. \ | ||
This could be because no jacobian autodiff backend is loaded \ | ||
or the loaded backends don't support the problem.")) | ||
end | ||
|
||
function incompatible_backend_and_problem( | ||
prob::AbstractNonlinearProblem, ad::AbstractADType) | ||
!DI.check_available(ad) && return true | ||
SciMLBase.isinplace(prob) && !DI.check_inplace(ad) && return true | ||
return additional_incompatible_backend_check(prob, ad) | ||
end | ||
|
||
additional_incompatible_backend_check(::AbstractNonlinearProblem, ::AbstractADType) = false | ||
|
||
is_finite_differences_backend(ad::AbstractADType) = false | ||
is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true | ||
is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true |