Skip to content

Commit

Permalink
Merge pull request #416 from sharanry/enzyme_forward
Browse files Browse the repository at this point in the history
Add forward enzyme rules for init and solve
  • Loading branch information
ChrisRackauckas authored Nov 8, 2023
2 parents b9da6ac + 0935919 commit 9b540ba
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 1 deletion.
47 changes: 47 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,53 @@ using Enzyme

using EnzymeCore

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@assert !(prob isa Const)
res = func.val(prob.val, alg.val; kwargs...)
if RT <: Const
return res
end
dres = func.val(prob.dval, alg.val; kwargs...)
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
return Duplicated(res, dres)
end
error("Unsupported return type $RT")
end

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
@assert !(linsolve isa Const)

res = func.val(linsolve.val; kwargs...)

if RT <: Const
return res
end
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
end
b = deepcopy(linsolve.val.b)

db = linsolve.dval.b
dA = linsolve.dval.A

linsolve.val.b = db - dA * res.u
dres = func.val(linsolve.val; kwargs...)

linsolve.val.b = b

if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
return Duplicated(res, dres)
end

return Duplicated(res, dres)
end

function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
Expand Down
54 changes: 53 additions & 1 deletion test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff
using SafeTestsets

n = 4
A = rand(n, n);
Expand Down Expand Up @@ -161,4 +163,54 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
@test dA ≈ dA2 atol=5e-5
@test db1 ≈ db12
@test db2 ≈ db22
=#
=#

A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
for alg in (
LUFactorization(),
RFLUFactorization(),
# KrylovJL_GMRES(), fails
)
@show alg
function fb(b)
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fb(b1)

fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fd_jac

en_jac = map(onehot(b1)) do db1
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
eres[1]
end |> collect
@show en_jac

@test en_jac fd_jac rtol=1e-4

function fA(A)
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fA(A)

fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fd_jac

en_jac = map(onehot(A)) do dA
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
eres[1]
end |> collect
@show en_jac

@test en_jac fd_jac rtol=1e-4
end

0 comments on commit 9b540ba

Please sign in to comment.