diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 4b7e4765..9e1de8ea 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -12,6 +12,28 @@ SciMLBase.forwarddiffs_model(alg::Union{StochasticDiffEqNewtonAlgorithm, # Required for initialization, because ODECore._initialize_dae! calls it during # OverrideInit OrdinaryDiffEqCore.has_autodiff(::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm,StochasticDiffEqJumpAlgorithm}) = false +for T in [ + StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm, + StochasticDiffEqJumpNewtonAdaptiveAlgorithm, + StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm] + @eval OrdinaryDiffEqCore.has_autodiff(::$T) = true +end + +_alg_autodiff(::StochasticDiffEqNewtonAlgorithm{T, AD}) where {T, AD} = Val{AD}() +_alg_autodiff(::StochasticDiffEqNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}() +_alg_autodiff(::StochasticDiffEqJumpNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}() +_alg_autodiff(::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}() + +function OrdinaryDiffEqCore.alg_autodiff(alg) + ad = _alg_autodiff(alg) + if ad == Val(false) + return AutoFiniteDiff() + elseif ad == Val(true) + return AutoForwardDiff() + else + return SciMLBase._unwrap_val(ad) + end +end isadaptive(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = false isadaptive(alg::Union{StochasticDiffEqAdaptiveAlgorithm,StochasticDiffEqRODEAdaptiveAlgorithm,StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpDiffusionAdaptiveAlgorithm}) = true