Skip to content

Commit

Permalink
Add tests for other algs and handle cases of algs currently not suppo…
Browse files Browse the repository at this point in the history
…rted
  • Loading branch information
sharanry committed Nov 6, 2023
1 parent 950864c commit 3a3102b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 30 deletions.
4 changes: 3 additions & 1 deletion ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
if RT <: Const
return res

Check warning on line 34 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L33-L34

Added lines #L33 - L34 were not covered by tests
end

if linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")

Check warning on line 37 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L36-L37

Added lines #L36 - L37 were not covered by tests
end
b = deepcopy(linsolve.val.b)

Check warning on line 39 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L39

Added line #L39 was not covered by tests

db = linsolve.dval.b
Expand Down
66 changes: 37 additions & 29 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff
using SafeTestsets

n = 4
A = rand(n, n);
Expand Down Expand Up @@ -164,46 +165,53 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
@test db2 ≈ db22
=#


A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
function fb(b; alg = LUFactorization())
prob = LinearProblem(A, b)
for alg in (
LUFactorization(),
RFLUFactorization(),
# KrylovJL_GMRES(), fails
)
alg_str = string(alg)
@show alg_str
function fb(b)
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)
sol1 = solve(prob, alg)

sum(sol1.u)
end
fb(b1)
sum(sol1.u)
end
fb(b1)

fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fd_jac
fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fd_jac

en_jac = map(onehot(b1)) do db1
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
eres[1]
end |> collect
@show en_jac
en_jac = map(onehot(b1)) do db1
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
eres[1]
end |> collect
@show en_jac

@test en_jac fd_jac rtol=1e-6
@test en_jac fd_jac rtol=1e-6

function fA(A; alg = LUFactorization())
prob = LinearProblem(A, b1)
function fA(A)
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)
sol1 = solve(prob, alg)

sum(sol1.u)
end
fA(A)
sum(sol1.u)
end
fA(A)

fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fd_jac
fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fd_jac

en_jac = map(onehot(A)) do dA
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
eres[1]
end |> collect
@show en_jac
en_jac = map(onehot(A)) do dA
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
eres[1]
end |> collect
@show en_jac

@test en_jac fd_jac rtol=1e-6
@test en_jac fd_jac rtol=1e-6
end

0 comments on commit 3a3102b

Please sign in to comment.