From cbd1b95fde3c1e440d711fc2409c606122fbb0fa Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sun, 9 Jun 2024 09:17:58 -0400 Subject: [PATCH 1/3] use make_zero and make_zero! --- ext/DiffEqBaseEnzymeExt.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index aa5a1a0e0..b31c6f26c 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -28,10 +28,7 @@ function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWi SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; kwargs...) - dres = deepcopy(res[1])::RT - for v in dres.u - v .= 0 - end + dres = Enzyme.make_zero(res[1])::RT tup = (dres, res[2]) return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) end @@ -52,9 +49,7 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, end ptr.dval .+= darg end - for v in dres.u - v .= 0 - end + Enzyme.make_zero!(dres) return ntuple(_ -> nothing, Val(length(args) + 4)) end From 27381c0affabd76b941f2b32ba5d9ea20c64b135 Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sun, 9 Jun 2024 10:02:51 -0400 Subject: [PATCH 2/3] Enzyme compat for make_zero! --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1c1402c7c..8a342ae5e 100644 --- a/Project.toml +++ b/Project.toml @@ -70,7 +70,7 @@ DataStructures = "0.18" Distributions = "0.25" DocStringExtensions = "0.9" EnumX = "1" -Enzyme = "0.11.9, 0.12" +Enzyme = "0.12.12" EnzymeCore = "0.5, 0.6, 0.7" FastBroadcast = "0.2, 0.3" FastClosures = "0.3.2" From 549bdfcb6544938f219e178d545be219b765728b Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sun, 9 Jun 2024 11:31:58 -0400 Subject: [PATCH 3/3] fix make_zero! --- ext/DiffEqBaseEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index b31c6f26c..2b2b7c001 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -49,7 +49,7 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, end ptr.dval .+= darg end - Enzyme.make_zero!(dres) + Enzyme.make_zero!(dres.u) return ntuple(_ -> nothing, Val(length(args) + 4)) end