diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index f69d5233b..aa5a1a0e0 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -28,7 +28,10 @@ function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWi SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; kwargs...) - dres = Enzyme.make_zero(res[1])::RT + dres = deepcopy(res[1])::RT + for v in dres.u + v .= 0 + end tup = (dres, res[2]) return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) end @@ -49,7 +52,9 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, end ptr.dval .+= darg end - dres = Enzyme.make_zero(dres) + for v in dres.u + v .= 0 + end return ntuple(_ -> nothing, Val(length(args) + 4)) end