diff --git a/test/enzyme.jl b/test/enzyme.jl index dbacb70f1..e00553d69 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -6,10 +6,8 @@ A = rand(n, n); dA = zeros(n, n); b1 = rand(n); db1 = zeros(n); -b2 = rand(n); -db2 = zeros(n); -function f(A, b1, b2; alg = LUFactorization()) +function f(A, b1; alg = LUFactorization()) prob = LinearProblem(A, b1) sol1 = solve(prob, alg) @@ -18,16 +16,15 @@ function f(A, b1, b2; alg = LUFactorization()) norm(s1) end -f(A, b1, b2) # Uses BLAS +f(A, b1) # Uses BLAS -Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) +Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1)) dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1, b2), copy(A)) db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1)) @test dA ≈ dA2 @test db1 ≈ db12 -@test db2 == zeros(4) A = rand(n, n); dA = zeros(n, n); @@ -36,9 +33,6 @@ b1 = rand(n); db1 = zeros(n); db12 = zeros(n); -b2 = rand(n); -db2 = zeros(n); -db22 = zeros(n); - -@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12)), BatchDuplicated(copy(b2), (db2, db22))) -@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) \ No newline at end of file +# This is not legal, all args need to be batch'd at the same size +@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12))) +@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1))