diff --git a/Project.toml b/Project.toml index 1c1402c7c..8a342ae5e 100644 --- a/Project.toml +++ b/Project.toml @@ -70,7 +70,7 @@ DataStructures = "0.18" Distributions = "0.25" DocStringExtensions = "0.9" EnumX = "1" -Enzyme = "0.11.9, 0.12" +Enzyme = "0.12.12" EnzymeCore = "0.5, 0.6, 0.7" FastBroadcast = "0.2, 0.3" FastClosures = "0.3.2" diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index aa5a1a0e0..2b2b7c001 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -28,10 +28,7 @@ function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWi SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; kwargs...) - dres = deepcopy(res[1])::RT - for v in dres.u - v .= 0 - end + dres = Enzyme.make_zero(res[1])::RT tup = (dres, res[2]) return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) end @@ -52,9 +49,7 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, end ptr.dval .+= darg end - for v in dres.u - v .= 0 - end + Enzyme.make_zero!(dres.u) return ntuple(_ -> nothing, Val(length(args) + 4)) end