From 0d8a47a4046574944f423d36460b615b5695b48c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 Sep 2024 12:28:39 -0400 Subject: [PATCH] fix: forward rules aliasing issue --- ext/LinearSolveEnzymeExt.jl | 18 +++++++++++------- test/enzyme.jl | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index ddc37f63..abd2232e 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -17,9 +17,15 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1}, return nothing end end + dres = func.val(prob.dval, alg.val; kwargs...) - dres.b .= res.b == dres.b ? zero(dres.b) : dres.b - dres.A .= res.A == dres.A ? zero(dres.A) : dres.A + + if dres.b == res.b + dres.b .= false + end + if dres.A == res.A + dres.A .= false + end if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return Duplicated(res, dres) @@ -50,14 +56,12 @@ function EnzymeRules.forward( if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") end - b = deepcopy(linsolve.val.b) - db = linsolve.dval.b - dA = linsolve.dval.A + res = deepcopy(res) # Without this copy, the next solve will end up mutating the result - linsolve.val.b = db - dA * res.u + b = linsolve.val.b + linsolve.val.b = linsolve.dval.b - linsolve.dval.A * res.u dres = func.val(linsolve.val; kwargs...) - linsolve.val.b = b if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) diff --git a/test/enzyme.jl b/test/enzyme.jl index 9192b63a..b09c0de5 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -209,7 +209,7 @@ end en_jac = map(onehot(A)) do dA return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, Duplicated(A, dA), Const(b1), Const(alg))) - end |> collect |> (x -> reshape(x, n, n)) + end |> collect @show en_jac @test en_jac≈fd_jac rtol=1e-4