diff --git a/test/enzyme.jl b/test/enzyme.jl index 323f3e60..9192b63a 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -177,47 +177,39 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), A = rand(n, n); dA = zeros(n, n); b1 = rand(n); -for alg in ( + +function fnice(A, b, alg) + prob = LinearProblem(A, b) + sol1 = solve(prob, alg) + return sum(sol1.u) +end + +@testset for alg in ( LUFactorization(), RFLUFactorization() # KrylovJL_GMRES(), fails ) - @show alg - function fb(b) - prob = LinearProblem(A, b) - - sol1 = solve(prob, alg) + fb_closure = b -> fnice(A, b, alg) - sum(sol1.u) - end - fb(b1) - - fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec + fd_jac = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec @show fd_jac en_jac = map(onehot(b1)) do db1 - eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1)) - eres[1] + return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, + Const(A), Duplicated(b1, db1), Const(alg))) end |> collect @show en_jac @test en_jac≈fd_jac rtol=1e-4 - function fA(A) - prob = LinearProblem(A, b1) - - sol1 = solve(prob, alg) + fA_closure = A -> fnice(A, b1, alg) - sum(sol1.u) - end - fA(A) - - fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec + fd_jac = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec @show fd_jac en_jac = map(onehot(A)) do dA - eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA)) - eres[1] - end |> collect + return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, + Duplicated(A, dA), Const(b1), Const(alg))) + end |> collect |> (x -> reshape(x, n, n)) @show en_jac @test en_jac≈fd_jac rtol=1e-4