-
-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Enzyme extension #377
Add Enzyme extension #377
Changes from 6 commits
bb6d623
9f8d18f
391b602
ce7ffc0
a08386d
9273a20
84c5196
bb93d68
9d19db2
f9b0784
3b39753
cbb5f1d
9630121
b0d228d
c2ad2db
54f0722
e4f0785
d69af77
be91ba2
89e10df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
module LinearSolveEnzymeExt | ||
|
||
using LinearSolve | ||
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) | ||
|
||
|
||
using Enzyme | ||
|
||
using EnzymeCore | ||
|
||
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} | ||
res = func.val(prob.val, alg.val; kwargs...) | ||
dres = if EnzymeRules.width(config) == 1 | ||
func.val(prob.dval, alg.val; kwargs...) | ||
else | ||
(func.val(dval, alg.val; kwargs...) for dval in prob.dval) | ||
end | ||
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing) | ||
end | ||
|
||
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} | ||
return (nothing, nothing) | ||
end | ||
|
||
# y=inv(A) B | ||
# dA −= z y^T | ||
# dB += z, where z = inv(A^T) dy | ||
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} | ||
res = func.val(linsolve.val; kwargs...) | ||
dres = if EnzymeRules.width(config) == 1 | ||
deepcopy(res) | ||
else | ||
(deepcopy(res) for dval in linsolve.dval) | ||
end | ||
|
||
if EnzymeRules.width(config) == 1 | ||
dres.u .= 0 | ||
else | ||
for dr in dres | ||
dr.u .= 0 | ||
end | ||
end | ||
|
||
resvals = if EnzymeRules.width(config) == 1 | ||
dres.u | ||
else | ||
(dr.u for dr in dres) | ||
end | ||
|
||
cache = (copy(linsolve.val.A), res, resvals) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this copy necessary? |
||
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) | ||
end | ||
|
||
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} | ||
A, y, dys = cache | ||
|
||
@assert !(typeof(linsolve) <: Const) | ||
@assert !(typeof(linsolve) <: Active) | ||
|
||
if EnzymeRules.width(config) == 1 | ||
dys = (dys,) | ||
end | ||
|
||
dAs = if EnzymeRules.width(config) == 1 | ||
(linsolve.dval.A,) | ||
else | ||
(dval.A for dval in linsolve.dval) | ||
end | ||
|
||
dbs = if EnzymeRules.width(config) == 1 | ||
(linsolve.dval.b,) | ||
else | ||
(dval.b for dval in linsolve.dval) | ||
end | ||
|
||
for (dA, db, dy) in zip(dAs, dbs, dys) | ||
invprob = LinearSolve.LinearProblem(transpose(A), dy) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the forward pass the matrix |
||
z = solve(invprob; | ||
abstol = linsolve.val.abstol, | ||
reltol = linsolve.val.reltol, | ||
verbose = linsolve.val.verbose) | ||
|
||
dA .-= z * transpose(y) | ||
db .+= z | ||
dy .= eltype(dy)(0) | ||
end | ||
|
||
return (nothing,) | ||
end | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
using Enzyme, FiniteDiff | ||
using LinearSolve, LinearAlgebra, Test | ||
|
||
n = 4 | ||
A = rand(n, n); | ||
dA = zeros(n, n); | ||
b1 = rand(n); | ||
db1 = zeros(n); | ||
b2 = rand(n); | ||
db2 = zeros(n); | ||
|
||
function f(A, b1, b2; alg = LUFactorization()) | ||
prob = LinearProblem(A, b1) | ||
|
||
sol1 = solve(prob, alg) | ||
|
||
s1 = sol1.u | ||
norm(s1) | ||
end | ||
|
||
f(A, b1, b2) # Uses BLAS | ||
|
||
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) | ||
|
||
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1, b2), copy(A)) | ||
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1)) | ||
|
||
@test dA ≈ dA2 | ||
@test db1 ≈ db12 | ||
@test db2 == zeros(4) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this one required? It seems like it doesn't do much?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Init hits that global variable stuff, so we need a rule for corresponding shadow initialization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see