Skip to content

Commit

Permalink
Fix enzyme batch mode
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 28, 2023
1 parent 5a8aa51 commit e56227a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 20 deletions.
14 changes: 7 additions & 7 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.i
d_b .= 0
end
else
for i in 1:EnzymeRules.width(config)
if d_A !== prob_d_A[i]
prob_d_A[i] .+= d_A[i]
d_A[i] .= 0
for (_prob_d_A,_d_A,_prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
if _d_A !== _prob_d_A
_prob_d_A .+= _d_A
_d_A .= 0
end
if d_b !== prob_d_b[i]
prob_d_b[i] .+= d_b[i]
d_b[i] .= 0
if _d_b !== _prob_d_b
_prob_d_b .+= _d_b
_d_b .= 0
end
end
end
Expand Down
48 changes: 35 additions & 13 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,25 @@ b1 = rand(n);
db1 = zeros(n);
db12 = zeros(n);

#=
# Batch test fails
# Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075
# Batch test
n = 4
A = rand(n, n);
dA = zeros(n, n);
dA2 = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
db12 = zeros(n);

function fbatch(y, A, b1; alg = LUFactorization())
function f(A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg)
s1 = sol1.u
norm(s1)
end

function fbatch(y, A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg)
s1 = sol1.u
y[1] = norm(s1)
nothing
Expand All @@ -50,16 +60,28 @@ end
y = [0.0]
dy1 = [1.0]
dy2 = [1.0]
Enzyme.autodiff(Reverse, fbatch, Duplicated(y, dy1), Duplicated(copy(A), dA), Duplicated(copy(b1), db1))

@test y[1] f(copy(A),b1)
dA_2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
db1_2 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))

@test dA dA_2
@test db1 db1_2

y .= 0
dy1 .= 1
dy2 .= 1
dA .= 0
dA2 .= 0
db1 .= 0
db12 .= 0
Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))

dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
@test_broken dA ≈ dA_2
@test_broken dA2 ≈ dA_2
@test_broken db1 ≈ db1_2
@test_broken db12 ≈ db1_2
=#
@test dA dA_2
@test db1 db1_2
@test dA2 dA_2
@test db12 db1_2

function f(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)
Expand Down

0 comments on commit e56227a

Please sign in to comment.