Skip to content

Commit

Permalink
feat: check for branching for ReverseDiff(compile=true)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 4, 2024
1 parent 1075e2d commit 1d7df8d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
2 changes: 2 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -36,6 +37,7 @@ DifferentiationInterface = "0.6.1"
EnzymeCore = "0.8"
FastClosures = "0.3"
ForwardDiff = "0.10.36"
FunctionProperties = "0.1.2"
LinearAlgebra = "1.10"
Markdown = "1.10"
RecursiveArrayTools = "3"
Expand Down
10 changes: 6 additions & 4 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface
using EnzymeCore: EnzymeCore
using FastClosures: @closure
using FunctionProperties: hasbranching
using LinearAlgebra: norm
using Markdown: @doc_str
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction,
@add_kwonly, StandardNonlinearProblem, NullParameters, NonlinearProblem,
isinplace, warn_paramtype
@add_kwonly, StandardNonlinearProblem, NullParameters, isinplace,
warn_paramtype
using StaticArraysCore: StaticArray

const DI = DifferentiationInterface
Expand All @@ -30,8 +31,9 @@ include("autodiff.jl")
# Unexported Public API
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
@compat(public, (select_forward_mode_autodiff, select_reverse_mode_autodiff,
select_jacobian_autodiff))
@compat(public,
(select_forward_mode_autodiff, select_reverse_mode_autodiff,
select_jacobian_autodiff))

export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode,
AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
Expand Down
9 changes: 9 additions & 0 deletions lib/NonlinearSolveBase/src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const ReverseADs = [
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
ADTypes.AutoZygote(),
ADTypes.AutoTracker(),
ADTypes.AutoReverseDiff(; compile = true),
ADTypes.AutoReverseDiff(),
ADTypes.AutoFiniteDiff()
]
Expand Down Expand Up @@ -103,6 +104,14 @@ function incompatible_backend_and_problem(
end

additional_incompatible_backend_check(::AbstractNonlinearProblem, ::AbstractADType) = false
function additional_incompatible_backend_check(prob::AbstractNonlinearProblem,
::ADTypes.AutoReverseDiff{true})
if SciMLBase.isinplace(prob)
fu = prob.f.resid_prototype === nothing ? zero(prob.u0) : prob.f.resid_prototype
return hasbranching(prob.f, fu, prob.u0, prob.p)
end
return hasbranching(prob.f, prob.u0, prob.p)
end

is_finite_differences_backend(ad::AbstractADType) = false
is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true
Expand Down

0 comments on commit 1d7df8d

Please sign in to comment.