Skip to content

Commit

Permalink
Merge pull request #377 from wsmoses/master
Browse files Browse the repository at this point in the history
Add Enzyme extension
  • Loading branch information
ChrisRackauckas authored Sep 24, 2023
2 parents 37e5328 + 89e10df commit 1e6150e
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 3 deletions.
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -30,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Expand All @@ -42,6 +44,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
[extensions]
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = "Enzyme"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand Down Expand Up @@ -78,6 +81,8 @@ julia = "1.6"

[extras]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -95,4 +100,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
166 changes: 166 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
module LinearSolveEnzymeExt

using LinearSolve
using LinearSolve.LinearAlgebra
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)


using Enzyme

using EnzymeCore

# Enzyme reverse-mode rule (augmented forward pass) for `LinearSolve.init`.
#
# Calls `init` on the primal problem `prob.val` and once more on every shadow
# problem in `prob.dval` (a single shadow when the batch width is 1, otherwise
# an `ntuple` with one shadow cache per derivative direction).
#
# The tape handed to the matching `reverse` rule is
# `(d_A, d_b, prob_d_A, prob_d_b)`: the `A`/`b` fields of the shadow cache(s)
# followed by the `A`/`b` fields of the shadow problem(s).  For batch widths
# greater than 1 every tape entry is a lazy generator over the directions.
function EnzymeCore.EnzymeRules.augmented_primal(
        config,
        func::Const{typeof(LinearSolve.init)},
        ::Type{RT},
        prob::EnzymeCore.Annotation{LP},
        alg::Const;
        kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    width = EnzymeRules.width(config)
    primal_cache = func.val(prob.val, alg.val; kwargs...)

    if width == 1
        shadow_cache = func.val(prob.dval, alg.val; kwargs...)
        tape = (shadow_cache.A, shadow_cache.b, prob.dval.A, prob.dval.b)
    else
        shadow_cache = ntuple(Val(width)) do i
            Base.@_inline_meta
            func.val(prob.dval[i], alg.val; kwargs...)
        end
        tape = (
            (shadow.A for shadow in shadow_cache),
            (shadow.b for shadow in shadow_cache),
            (dprob.A for dprob in prob.dval),
            (dprob.b for dprob in prob.dval),
        )
    end

    return EnzymeCore.EnzymeRules.AugmentedReturn(primal_cache, shadow_cache, tape)
end

# Enzyme reverse-mode rule (reverse pass) for `LinearSolve.init`.
#
# `cache` is the tape `(d_A, d_b, prob_d_A, prob_d_b)` produced by the matching
# `augmented_primal`: the shadow cache's `A`/`b` and the shadow problem's
# `A`/`b`.  For every derivative direction, whatever has been accumulated in
# the shadow cache is added onto the shadow problem and the shadow cache is
# then reset to zero.  The transfer is skipped when both sides are the very
# same object (`===`), because adding an array to itself and zeroing it would
# corrupt the gradient that is already stored in place.
#
# With batch width 1 the tape holds the arrays themselves; for larger widths it
# holds one iterable per entry.  Those iterables are walked in lockstep with
# `zip`, so they only need to be iterable (the augmented pass stores lazy
# generators, which cannot be indexed) and each direction is compared against
# its own counterpart rather than against the whole container.
#
# Returns `(nothing, nothing)`, one entry for each of `prob` and `alg`.
function EnzymeCore.EnzymeRules.reverse(
        config,
        func::Const{typeof(LinearSolve.init)},
        ::Type{RT},
        cache,
        prob::EnzymeCore.Annotation{LP},
        alg::Const;
        kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    d_A, d_b, prob_d_A, prob_d_b = cache

    # Normalise both layouts to "an iterable of per-direction 4-tuples".
    directions = if EnzymeRules.width(config) == 1
        ((d_A, prob_d_A, d_b, prob_d_b),)
    else
        zip(d_A, prob_d_A, d_b, prob_d_b)
    end

    for (cache_dA, prob_dA, cache_db, prob_db) in directions
        if cache_dA !== prob_dA
            prob_dA .+= cache_dA
            cache_dA .= 0
        end
        if cache_db !== prob_db
            prob_db .+= cache_db
            cache_db .= 0
        end
    end

    return (nothing, nothing)
end

# Enzyme reverse-mode rule (augmented forward pass) for `LinearSolve.solve!`.
#
# Adjoint of the linear solve, as applied by the matching `reverse` rule:
#   y  = inv(A) b
#   dA -= z y^T
#   db += z,   where z = inv(A^T) dy
#
# The primal solve is run on `linsolve.val`.  Each shadow result is a
# `deepcopy` of the primal result whose `u` is zeroed, so it can receive the
# incoming adjoint `dy`.  The tape stores, in order: a copy of the primal
# solution `u`, the shadow `u` array(s), a `deepcopy` of the cache taken after
# the solve, and tuples with the shadow `A` and shadow `b` of every direction.
function EnzymeCore.EnzymeRules.augmented_primal(
        config,
        func::Const{typeof(LinearSolve.solve!)},
        ::Type{RT},
        linsolve::EnzymeCore.Annotation{LP};
        kwargs...) where {RT, LP <: LinearSolve.LinearCache}
    width = EnzymeRules.width(config)
    res = func.val(linsolve.val; kwargs...)

    if width == 1
        dres = deepcopy(res)
        dres.u .= 0
        resvals = dres.u
        # Wrapped in 1-tuples so the reverse pass can `zip` over directions.
        dAs = (linsolve.dval.A,)
        dbs = (linsolve.dval.b,)
    else
        dres = ntuple(Val(width)) do i
            Base.@_inline_meta
            deepcopy(res)
        end
        foreach(dr -> (dr.u .= 0), dres)
        resvals = map(dr -> dr.u, dres)
        dAs = ntuple(Val(width)) do i
            Base.@_inline_meta
            linsolve.dval[i].A
        end
        dbs = ntuple(Val(width)) do i
            Base.@_inline_meta
            linsolve.dval[i].b
        end
    end

    # Snapshot of the cache after the solve, kept for the reverse pass.
    cachesolve = deepcopy(linsolve.val)

    tape = (copy(res.u), resvals, cachesolve, dAs, dbs)
    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, tape)
end

# Enzyme reverse-mode rule (reverse pass) for `LinearSolve.solve!`.
#
# `cache` is the tape from the matching `augmented_primal`:
# `(y, dys, _linsolve, dAs, dbs)` = primal solution, shadow solution(s),
# a copy of the `LinearCache` taken after the solve, shadow `A`s, shadow `b`s.
#
# For every derivative direction the adjoint system is solved for `z` and
#   dA .-= z * transpose(y)
#   db .+= z
# are accumulated, after which the consumed adjoint `dy` is reset to zero.
#
# How `z` is obtained depends on what the saved cache holds:
#   * a `Factorization` (or a tuple whose first entry is one): reuse its adjoint;
#   * a Krylov method: solve a fresh problem on `transpose(A)` with the same
#     algorithm and the tolerances/verbosity stored in the saved cache;
#   * anything else: raise an error naming the unsupported algorithm.
#
# Returns `(nothing,)`, one entry for the single argument `linsolve`.
function EnzymeCore.EnzymeRules.reverse(
        config,
        func::Const{typeof(LinearSolve.solve!)},
        ::Type{RT},
        cache,
        linsolve::EnzymeCore.Annotation{LP};
        kwargs...) where {RT, LP <: LinearSolve.LinearCache}
    y, dys, _linsolve, dAs, dbs = cache

    # The rule writes into shadow storage, so the cache must carry a shadow.
    @assert !(typeof(linsolve) <: Const)
    @assert !(typeof(linsolve) <: Active)

    # Width 1 stores the bare array; wrap it so all widths share one loop.
    if EnzymeRules.width(config) == 1
        dys = (dys,)
    end

    for (dA, db, dy) in zip(dAs, dbs, dys)
        z = if _linsolve.cacheval isa Factorization
            _linsolve.cacheval' \ dy
        elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
            _linsolve.cacheval[1]' \ dy
        elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
            # Doesn't modify `A`, so it's safe to just reuse it
            invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
            # `_linsolve` is already the saved cache (not an annotation), so its
            # settings are read directly.  `solve` returns a solution object;
            # the adjoint vector is its `u` field.
            adjsol = solve(invprob, _linsolve.alg;
                abstol = _linsolve.abstol,
                reltol = _linsolve.reltol,
                verbose = _linsolve.verbose)
            adjsol.u
        else
            error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
        end

        dA .-= z * transpose(y)
        db .+= z
        dy .= eltype(dy)(0)
    end

    return (nothing,)
end

end
3 changes: 3 additions & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@ function __init__()
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
include("../ext/LinearSolveMKLExt.jl")
end
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("../ext/LinearSolveEnzymeExt.jl")
end
end
end
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ end
end
end

test_algs = if VERISON >= v"1.9"
test_algs = if VERSION >= v"1.9"
(LUFactorization(),
QRFactorization(),
SVDFactorization(),
Expand Down
125 changes: 125 additions & 0 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test

# Problem size and random data for the first test; `dA` / `db1` are the
# zero-initialised shadow buffers that Enzyme accumulates gradients into.
n = 4
A = rand(n, n)
dA = zeros(n, n)
b1 = rand(n)
db1 = zeros(n)

# Scalar test objective: the norm of the solution of `A * u = b1`, obtained
# with the one-shot `solve` interface and the linear solver `alg`.
function f(A, b1; alg = LUFactorization())
    solution = solve(LinearProblem(A, b1), alg)
    return norm(solution.u)
end

f(A, b1) # Uses BLAS

# Reverse-mode gradients from Enzyme are accumulated into `dA` and `db1`.
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))

# Reference gradients from ForwardDiff; the argument that is not being
# differentiated is converted to the dual element type of `x`.
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))

@test dA ≈ dA2
@test db1 ≈ db12

# Fresh random data plus two shadow buffers per input (`dA`/`dA2`,
# `db1`/`db12`), as used by the batched test that follows.
A = rand(n, n)
dA = zeros(n, n)
dA2 = zeros(n, n)
b1 = rand(n)
db1 = zeros(n)
db12 = zeros(n)

#=
# Batch test fails
# Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075
function fbatch(y, A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg)
s1 = sol1.u
y[1] = norm(s1)
nothing
end
y = [0.0]
dy1 = [1.0]
dy2 = [1.0]
Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))
dA_2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
db1_2 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
@test_broken dA ≈ dA_2
@test_broken dA2 ≈ dA_2
@test_broken db1 ≈ db1_2
@test_broken db12 ≈ db1_2
=#

# Scalar test objective exercising the caching interface: one cache is
# initialised for `A`, solved against `b1`, then re-used for `b2` after
# swapping the right-hand side.  Returns the norm of the sum of both solutions.
function f(A, b1, b2; alg = LUFactorization())
    cache = init(LinearProblem(A, b1), alg)

    # Copy the first solution before the cache is solved again.
    first_sol = copy(solve!(cache).u)

    cache.b = b2
    second_sol = solve!(cache).u

    return norm(first_sol + second_sol)
end

# Fresh random data and zeroed shadow buffers for the cache re-use test.
A = rand(n, n)
dA = zeros(n, n)
b1 = rand(n)
db1 = zeros(n)
b2 = rand(n)
db2 = zeros(n)

f(A, b1, b2)
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
    Duplicated(copy(b2), db2))

# Reference gradients from ForwardDiff, one input at a time.
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test dA ≈ dA2
@test db1 ≈ db12
@test db2 ≈ db22

# Same objective as the three-argument `f`, but defaulting to
# `RFLUFactorization()`: solve against `b1`, swap in `b2` on the same cache,
# solve again, and return the norm of the sum of both solutions.
function f2(A, b1, b2; alg = RFLUFactorization())
    cache = init(LinearProblem(A, b1), alg)

    # Copy the first solution before the cache is solved again.
    first_sol = copy(solve!(cache).u)

    cache.b = b2
    second_sol = solve!(cache).u

    return norm(first_sol + second_sol)
end

f2(A, b1, b2)

# Reset the shadow buffers before differentiating `f2` on the same data.
dA = zeros(n, n)
db1 = zeros(n)
db2 = zeros(n)
Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
    Duplicated(copy(b2), db2))

# Compared against the ForwardDiff reference gradients computed for `f` above.
@test dA ≈ dA2
@test db1 ≈ db12
@test db2 ≈ db22

#=
function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end
Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
@test dA ≈ dA2 atol=5e-5
@test db1 ≈ db12
@test db2 ≈ db22
=#
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ const HAS_EXTENSIONS = isdefined(Base, :get_extension)

if GROUP == "All" || GROUP == "Core"
@time @safetestset "Basic Tests" include("basictests.jl")
@time @safetestset "Re-solve" include("resolve.jl")
VERSION >= v"1.9" && @time @safetestset "Re-solve" include("resolve.jl")
@time @safetestset "Zero Initialization Tests" include("zeroinittests.jl")
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
@time @safetestset "Traits" include("traits.jl")
end

Expand Down

0 comments on commit 1e6150e

Please sign in to comment.