Skip to content

Commit

Permalink
feat: properly implement has_autodiff and alg_autodiff
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 9, 2024
1 parent baa5adf commit 86daec7
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@ SciMLBase.forwarddiffs_model(alg::Union{StochasticDiffEqNewtonAlgorithm,
# Required for initialization, because ODECore._initialize_dae! calls it during
# OverrideInit
OrdinaryDiffEqCore.has_autodiff(::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm,StochasticDiffEqJumpAlgorithm}) = false
for T in [
StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm,
StochasticDiffEqJumpNewtonAdaptiveAlgorithm,
StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm]
@eval OrdinaryDiffEqCore.has_autodiff(::$T) = true
end

_alg_autodiff(::StochasticDiffEqNewtonAlgorithm{T, AD}) where {T, AD} = Val{AD}()
_alg_autodiff(::StochasticDiffEqNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}()
_alg_autodiff(::StochasticDiffEqJumpNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}()
_alg_autodiff(::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}()

function OrdinaryDiffEqCore.alg_autodiff(alg)
ad = _alg_autodiff(alg)
if ad == Val(false)
return AutoFiniteDiff()
elseif ad == Val(true)
return AutoForwardDiff()
else
return SciMLBase._unwrap_val(ad)
end
end

isadaptive(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = false
isadaptive(alg::Union{StochasticDiffEqAdaptiveAlgorithm,StochasticDiffEqRODEAdaptiveAlgorithm,StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpDiffusionAdaptiveAlgorithm}) = true
Expand Down

0 comments on commit 86daec7

Please sign in to comment.