-
-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #377 from wsmoses/master
Add Enzyme extension
- Loading branch information
Showing
6 changed files
with
303 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
module LinearSolveEnzymeExt | ||
|
||
using LinearSolve | ||
using LinearSolve.LinearAlgebra | ||
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) | ||
|
||
|
||
using Enzyme | ||
|
||
using EnzymeCore | ||
|
||
# Augmented-primal (forward sweep) rule for `LinearSolve.init`.
#
# Builds the primal cache from `prob.val` and one shadow cache per batch lane from
# `prob.dval`, then tapes the shadow `A`/`b` arrays of both the shadow caches and
# the shadow problems so the reverse sweep can move gradient from the former into
# the latter.
#
# Tape layout: `(d_A, d_b, prob_d_A, prob_d_b)`. Each entry is a bare array when the
# batch width is 1 and a tuple with one array per lane otherwise. The per-lane
# entries are materialised tuples rather than lazy generators, because a
# `Base.Generator` cannot be indexed and would rebuild its values on every pass.
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    width = EnzymeRules.width(config)

    # Primal result first, exactly as a plain call to `init` would produce it.
    res = func.val(prob.val, alg.val; kwargs...)

    if width == 1
        dres = func.val(prob.dval, alg.val; kwargs...)
        d_A, d_b = dres.A, dres.b
        prob_d_A, prob_d_b = prob.dval.A, prob.dval.b
    else
        dres = ntuple(Val(width)) do i
            Base.@_inline_meta
            func.val(prob.dval[i], alg.val; kwargs...)
        end
        d_A = map(shadow -> shadow.A, dres)
        d_b = map(shadow -> shadow.b, dres)
        prob_d_A = ntuple(Val(width)) do i
            Base.@_inline_meta
            prob.dval[i].A
        end
        prob_d_b = ntuple(Val(width)) do i
            Base.@_inline_meta
            prob.dval[i].b
        end
    end

    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
end
|
||
# Reverse sweep of the rule for `LinearSolve.init`.
#
# `cache` is the tape written by the matching `augmented_primal`:
# `(d_A, d_b, prob_d_A, prob_d_b)`, i.e. the shadow `A`/`b` held by the shadow
# cache(s) and the shadow `A`/`b` of the problem. For every lane the gradient
# accumulated in the cache shadow is added onto the problem shadow and the cache
# shadow is reset to zero. When both refer to the very same array (`===`) nothing
# is done; adding the array to itself and then zeroing it would destroy the
# gradient.
#
# Returns `(nothing, nothing)`: neither `prob` nor `alg` receives an active return.
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    d_A, d_b, prob_d_A, prob_d_b = cache

    # Normalise to "one 4-tuple of arrays per lane". For width 1 the tape holds the
    # arrays themselves; for batches it holds per-lane collections, which are walked
    # in lockstep so the identity check is done lane by lane (and so the collections
    # only need to be iterable, not indexable).
    lanes = if EnzymeRules.width(config) == 1
        ((d_A, d_b, prob_d_A, prob_d_b),)
    else
        zip(d_A, d_b, prob_d_A, prob_d_b)
    end

    for (dA, db, pdA, pdb) in lanes
        if dA !== pdA
            pdA .+= dA
            dA .= 0
        end
        if db !== pdb
            pdb .+= db
            db .= 0
        end
    end

    return (nothing, nothing)
end
|
||
# Rule for `solve!`. With y = inv(A) * b and incoming adjoint dy:
#   z   = inv(A^T) * dy
#   dA -= z * y^T
#   db += z
#
# Forward sweep: run the primal solve, create one zeroed shadow solution per batch
# lane, and tape everything the reverse sweep needs:
#   (copy of the primal solution vector, shadow solution vector(s),
#    deep copy of the cache taken after the solve, shadow A's, shadow b's)
# The shadow A's and b's are always taped as tuples; the shadow solution vectors
# are a bare array for width 1 and a tuple otherwise.
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
    res = func.val(linsolve.val; kwargs...)

    if EnzymeRules.width(config) == 1
        dres = deepcopy(res)
        dres.u .= 0
        resvals = dres.u
        dAs = (linsolve.dval.A,)
        dbs = (linsolve.dval.b,)
    else
        dres = ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            deepcopy(res)
        end
        # Shadows start from a copy of the primal solution, so clear their `u`.
        foreach(dres) do shadow
            shadow.u .= 0
        end
        resvals = map(shadow -> shadow.u, dres)
        dAs = ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            linsolve.dval[i].A
        end
        dbs = ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            linsolve.dval[i].b
        end
    end

    # Snapshot of the cache as it is right after this solve; the reverse sweep
    # reads `cacheval`, `alg` and `A` from it.
    cachesolve = deepcopy(linsolve.val)

    tape = (copy(res.u), resvals, cachesolve, dAs, dbs)
    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, tape)
end
|
||
# Reverse sweep of the rule for `solve!`.
#
# `cache` is the tape written by the matching `augmented_primal`:
#   y         copy of the primal solution vector
#   dys       shadow solution vector(s) carrying the incoming adjoint
#   _linsolve deep copy of the `LinearCache` taken right after the primal solve
#   dAs, dbs  tuples of shadow `A` / shadow `b`, one entry per lane
#
# For each lane it solves the adjoint system for `z`, applies
#   dA -= z * y^T,   db += z
# and then clears the consumed adjoint `dy`.
#
# How `z` is obtained depends on what the taped cache holds:
#   * a `Factorization` in `cacheval` (directly, or as first tuple element):
#     reuse it via its adjoint;
#   * a Krylov method: solve the transposed system again with the same algorithm
#     and tolerances;
#   * anything else raises an error.
#
# Returns `(nothing,)`: `linsolve` receives no active return.
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
    y, dys, _linsolve, dAs, dbs = cache

    # The rule writes into shadow storage, so the cache must carry shadows.
    @assert !(typeof(linsolve) <: Const)
    @assert !(typeof(linsolve) <: Active)

    # For width 1 the shadow solution is taped bare; wrap it so every width can be
    # handled by the same per-lane loop.
    if EnzymeRules.width(config) == 1
        dys = (dys,)
    end

    for (dA, db, dy) in zip(dAs, dbs, dys)
        z = if _linsolve.cacheval isa Factorization
            _linsolve.cacheval' \ dy
        elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
            _linsolve.cacheval[1]' \ dy
        elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
            # Doesn't modify `A`, so it's safe to just reuse it
            invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
            # `_linsolve` is the taped cache itself (not an Enzyme annotation), so
            # its settings are read directly. Take `.u` so `z` is a plain array
            # like in the factorization branches.
            LinearSolve.solve(invprob, _linsolve.alg;
                abstol = _linsolve.abstol,
                reltol = _linsolve.reltol,
                verbose = _linsolve.verbose).u
        else
            error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
        end

        dA .-= z * transpose(y)
        db .+= z
        # The adjoint of this solve has been fully propagated; reset it.
        dy .= eltype(dy)(0)
    end

    return (nothing,)
end
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
using Enzyme, ForwardDiff | ||
using LinearSolve, LinearAlgebra, Test | ||
|
||
# Problem size and random inputs for the first test. `dA` and `db1` are the
# zero-initialised shadow buffers that Enzyme accumulates gradients into.
n = 4
A = rand(n, n)
dA = zero(A)
b1 = rand(n)
db1 = zero(b1)
|
||
# Solve `A * x = b1` with `alg` through the one-shot `solve` interface and return
# the norm of the solution vector.
function f(A, b1; alg = LUFactorization())
    sol = solve(LinearProblem(A, b1), alg)
    return norm(sol.u)
end
|
||
# Plain primal call before differentiating.
f(A, b1) # Uses BLAS

# Reverse-mode gradients with Enzyme; results are accumulated into `dA` / `db1`.
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))

# Reference gradients with ForwardDiff. The argument that is not differentiated is
# converted to the dual element type so both inputs share one eltype.
dA2 = ForwardDiff.gradient(copy(A)) do x
    f(x, eltype(x).(b1))
end
db12 = ForwardDiff.gradient(copy(b1)) do x
    f(eltype(x).(A), x)
end

@test isapprox(dA, dA2)
@test isapprox(db1, db12)
|
||
# Fresh random inputs plus two zeroed shadow buffers per argument
# (`dA`/`dA2` for `A`, `db1`/`db12` for `b1`).
A = rand(n, n)
dA = zero(A)
dA2 = zero(A)
b1 = rand(n)
db1 = zero(b1)
db12 = zero(b1)
|
||
#= | ||
# Batch test fails | ||
# Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075 | ||
function fbatch(y, A, b1; alg = LUFactorization()) | ||
prob = LinearProblem(A, b1) | ||
sol1 = solve(prob, alg) | ||
s1 = sol1.u | ||
y[1] = norm(s1) | ||
nothing | ||
end | ||
y = [0.0] | ||
dy1 = [1.0] | ||
dy2 = [1.0] | ||
Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12))) | ||
dA_2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A)) | ||
db1_2 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) | ||
@test_broken dA ≈ dA_2 | ||
@test_broken dA2 ≈ dA_2 | ||
@test_broken db1 ≈ db1_2 | ||
@test_broken db12 ≈ db1_2 | ||
=# | ||
|
||
# Solve twice through the caching interface: first with right-hand side `b1`,
# then with `b2` after swapping it into the same cache. Returns the norm of the
# sum of both solutions. The first solution is copied before the second solve.
function f(A, b1, b2; alg = LUFactorization())
    cache = init(LinearProblem(A, b1), alg)
    first_sol = copy(solve!(cache).u)
    cache.b = b2
    second_sol = solve!(cache).u
    return norm(first_sol + second_sol)
end
|
||
# Fresh random inputs and zeroed shadow buffers for the two-solve test.
A = rand(n, n)
dA = zero(A)
b1 = rand(n)
db1 = zero(b1)
b2 = rand(n)
db2 = zero(b2)

# Plain primal call, then reverse-mode gradients accumulated into the shadows.
f(A, b1, b2)
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

# Reference gradients with ForwardDiff, one input at a time; the other inputs are
# converted to the dual element type.
dA2 = ForwardDiff.gradient(copy(A)) do x
    f(x, eltype(x).(b1), eltype(x).(b2))
end
db12 = ForwardDiff.gradient(copy(b1)) do x
    f(eltype(x).(A), x, eltype(x).(b2))
end
db22 = ForwardDiff.gradient(copy(b2)) do x
    f(eltype(x).(A), eltype(x).(b1), x)
end

@test isapprox(dA, dA2)
@test isapprox(db1, db12)
@test isapprox(db2, db22)
|
||
# Same two-solve computation as the three-argument `f`, but defaulting to
# `RFLUFactorization()`: solve with `b1`, swap `b2` into the cache, solve again,
# and return the norm of the summed solutions.
function f2(A, b1, b2; alg = RFLUFactorization())
    cache = init(LinearProblem(A, b1), alg)
    first_sol = copy(solve!(cache).u)
    cache.b = b2
    second_sol = solve!(cache).u
    return norm(first_sol + second_sol)
end
|
||
# Plain primal call, then reset the shadow buffers before differentiating `f2`.
f2(A, b1, b2)
dA, db1, db2 = zeros(n, n), zeros(n), zeros(n)
Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

# Compared against the ForwardDiff reference gradients already held in
# `dA2`, `db12` and `db22`.
@test isapprox(dA, dA2)
@test isapprox(db1, db12)
@test isapprox(db2, db22)
|
||
#= | ||
function f3(A, b1, b2; alg = KrylovJL_GMRES()) | ||
prob = LinearProblem(A, b1) | ||
cache = init(prob, alg) | ||
s1 = copy(solve!(cache).u) | ||
cache.b = b2 | ||
s2 = solve!(cache).u | ||
norm(s1 + s2) | ||
end | ||
Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) | ||
@test dA ≈ dA2 atol=5e-5 | ||
@test db1 ≈ db12 | ||
@test db2 ≈ db22 | ||
=# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters