Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme extension #377

Merged
merged 20 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -30,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Expand All @@ -42,6 +44,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
[extensions]
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = "Enzyme"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand Down Expand Up @@ -78,6 +81,8 @@ julia = "1.6"

[extras]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -95,4 +100,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
137 changes: 137 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
module LinearSolveEnzymeExt

using LinearSolve
using LinearSolve.LinearAlgebra
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)


using Enzyme

using EnzymeCore

# Augmented-primal (forward sweep) rule for `LinearSolve.init` in reverse mode.
#
# Calls `init` on the primal problem and once on every shadow problem, then records
# the `A` and `b` arrays held by the shadow cache(s) as the tape for the reverse rule.
#
# Returns `AugmentedReturn(primal_cache, shadow_cache_or_tuple, (d_A, d_b))`.
# For width 1 the shadow entries are single objects; for batch mode they are tuples
# with one entry per shadow.
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    res = func.val(prob.val, alg.val; kwargs...)
    batched = EnzymeRules.width(config) != 1

    dres = if batched
        # `map` over the tuple of shadow problems returns a tuple. A lazy generator
        # here is rejected by Enzyme with a "return type mismatch" error and would
        # also re-run `init` every time it is iterated.
        map(dval -> func.val(dval, alg.val; kwargs...), prob.dval)
    else
        func.val(prob.dval, alg.val; kwargs...)
    end

    # Shadow arrays handed to the reverse rule.
    d_A = batched ? map(dcache -> dcache.A, dres) : dres.A
    d_b = batched ? map(dcache -> dcache.b, dres) : dres.b

    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b))
end

# Reverse rule for `LinearSolve.init`.
#
# `cache` is the `(d_A, d_b)` tape produced by the matching augmented primal: the
# shadow arrays stored in the shadow cache(s). Whenever such an array is not the very
# same object as the corresponding array of the shadow problem, its contents are
# accumulated into the shadow problem and the cached array is reset to zero.
#
# Returns `(nothing, nothing)`: neither `prob` nor `alg` receives an active return.
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    d_A, d_b = cache

    # Normalise width 1 to one-element tuples so a single loop handles both modes.
    dprobs, dAs, dbs = if EnzymeRules.width(config) == 1
        (prob.dval,), (d_A,), (d_b,)
    else
        prob.dval, d_A, d_b
    end

    # In batch mode `prob.dval` is a tuple of shadow problems, so the shadow must be
    # selected first and its field read second (`prob.dval[i].A`), never
    # `prob.dval.A[i]`.
    for (dprob, dA, db) in zip(dprobs, dAs, dbs)
        if dA !== dprob.A
            dprob.A .+= dA
            dA .= 0
        end
        if db !== dprob.b
            dprob.b .+= db
            db .= 0
        end
    end

    return (nothing, nothing)
end

# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
# Augmented-primal (forward sweep) rule for `LinearSolve.solve!` in reverse mode.
#
# Runs the primal solve, then builds one zero-initialised shadow of the result per
# derivative direction. The tape handed to the reverse rule is
# `(primal_result, shadow_u)`, where `shadow_u` is the `u` array of the shadow result
# (a tuple of such arrays in batch mode).
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
    res = func.val(linsolve.val; kwargs...)

    if EnzymeRules.width(config) == 1
        dres = deepcopy(res)
        dres.u .= 0
        resvals = dres.u
    else
        # Materialise the shadows as a tuple. With a lazy generator every iteration
        # produced a fresh `deepcopy`, so the zeroing below never reached the objects
        # that were actually returned, and the shadow was not a tuple.
        dres = map(_ -> deepcopy(res), linsolve.dval)
        for dr in dres
            dr.u .= 0
        end
        resvals = map(dr -> dr.u, dres)
    end

    cache = (res, resvals)
    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys = cache
_linsolve = linsolve.val
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still wrong, because linsolve could still have been overwritten between the forward and the reverse pass. You need to cache it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay was just about to ask that, thanks. I think with that this may be completed. Though check the batch syntax in the test: the test still errors with BatchDuplicated and I'm not sure what to do there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the error log from?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ERROR: TypeError: in ccall argument 6, expected Tuple{Float64, Float64}, got a value of type Float64
Stacktrace:
 [1] macro expansion
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9774 [inlined]
 [2] enzyme_call
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9452 [inlined]
 [3] CombinedAdjointThunk
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9415 [inlined]
 [4] autodiff
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:213 [inlined]
 [5] autodiff
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:236 [inlined]
 [6] autodiff(::ReverseMode{false, FFIABI}, ::typeof(f), ::BatchDuplicated{Matrix{Float64}, 2}, ::BatchDuplicated{Vector{Float64}, 2})
   @ Enzyme C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:222
 [7] top-level scope
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:36

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that's an easy one [which we should fix]. You can't use an active return right now in batch mode (which also makes little sense here, since you'd back-propagate the same value to each). Just wrap that func in a closure that stores it to a vector or something.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, yeah the test was a bit dumb but just a quick sanity check 😓. Fixing that gives:

ERROR: Enzyme execution failed.
Enzyme: Augmented forward pass custom rule Tuple{EnzymeCore.EnzymeRules.ConfigWidth{2, true, true, (false, false, false)}, Const{typeof(init)}, Type{BatchDuplicated{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, 2}}, BatchDuplicated{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, 2}, Const{LUFactorization{RowMaximum}}} return type mismatch, expected EnzymeCore.EnzymeRules.AugmentedReturn{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, Tuple{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}}, Any} found EnzymeCore.EnzymeRules.AugmentedReturn{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, 
LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, Tuple{Base.Generator{Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, LinearSolveEnzymeExt.var"#3#6"}, Base.Generator{Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, LinearSolveEnzymeExt.var"#4#7"}}}
Stacktrace:
 [1] #solve#5
   @ C:\Users\accou\.julia\dev\LinearSolve\src\common.jl:193
 [2] solve
   @ C:\Users\accou\.julia\dev\LinearSolve\src\common.jl:190
 [3] #fbatch#207
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:39
 [4] fbatch
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:36
 [5] fbatch
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:0

Stacktrace:
 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:3066


@assert !(typeof(linsolve) <: Const)
@assert !(typeof(linsolve) <: Active)

if EnzymeRules.width(config) == 1
dys = (dys,)
end

dAs = if EnzymeRules.width(config) == 1
(linsolve.dval.A,)
else
(dval.A for dval in linsolve.dval)
end

dbs = if EnzymeRules.width(config) == 1
(linsolve.dval.b,)
else
(dval.b for dval in linsolve.dval)
end

# Back-propagate every shadow `dy` of the solution through `y = inv(A) b`:
# with `z = inv(A^T) dy`, accumulate `dA -= z y^T` and `db += z`, then reset `dy`.
for (dA, db, dy) in zip(dAs, dbs, dys)
    # All solver state is read from `_linsolve` (the primal cache, `linsolve.val`).
    # `linsolve` itself is the Enzyme annotation wrapper and has no `cacheval` or
    # `alg` field, so reading through it throws as soon as the first branch is skipped.
    z = if _linsolve.cacheval isa Factorization
        _linsolve.cacheval' \ dy
    elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
        _linsolve.cacheval[1]' \ dy
    elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
        # Doesn't modify `A`, so it's safe to just reuse it
        invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
        # Re-use the cache's algorithm and tolerances, and take `.u` so that `z` is
        # the solution array rather than the solution object.
        solve(invprob, _linsolve.alg;
            abstol = _linsolve.abstol,
            reltol = _linsolve.reltol,
            verbose = _linsolve.verbose).u
    else
        error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
    end

    dA .-= z * transpose(y)
    db .+= z
    dy .= eltype(dy)(0)
end

return (nothing,)
end

end
3 changes: 3 additions & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@ function __init__()
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
include("../ext/LinearSolveMKLExt.jl")
end
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("../ext/LinearSolveEnzymeExt.jl")
end
end
end
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ end
end
end

test_algs = if VERISON >= v"1.9"
test_algs = if VERSION >= v"1.9"
(LUFactorization(),
QRFactorization(),
SVDFactorization(),
Expand Down
71 changes: 71 additions & 0 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using Enzyme, FiniteDiff
using LinearSolve, LinearAlgebra, Test

# Problem size and random data for the first check; `dA` and `db1` are the
# zero-initialised shadow arrays passed to Enzyme alongside `A` and `b1`.
n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);

# Scalar test objective: solve the linear problem built from `A` and `b1` with `alg`
# and return the norm of the solution vector.
function f(A, b1; alg = LUFactorization())
    linprob = LinearProblem(A, b1)
    solution = solve(linprob, alg)
    return norm(solution.u)
end

# Plain evaluation of the objective before differentiating it.
f(A, b1) # Uses BLAS

# Reverse-mode differentiation of `f`; `dA` and `db1` are the shadows of `A` and `b1`.
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))

# Finite-difference gradients used as the reference values.
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))

@test dA ≈ dA2
@test db1 ≈ db12

# Fresh data for the batched check: two shadows per argument.
A = rand(n, n);
dA = zeros(n, n);
dA2 = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
db12 = zeros(n);

# Batch mode is marked as broken, so this call and the comparisons below are
# recorded with `@test_broken` rather than `@test`.
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))

# Finite-difference reference gradients for the batched data.
dA_2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))

@test_broken dA ≈ dA_2
@test_broken dA2 ≈ dA_2
@test_broken db1 ≈ db1_2
@test_broken db12 ≈ db1_2

# Scalar test objective exercising the caching interface: build one cache from `A`
# and `b1`, solve, replace the right-hand side with `b2`, solve again, and return
# the norm of the sum of the two solutions.
function f(A, b1, b2; alg = LUFactorization())
    linprob = LinearProblem(A, b1)
    lincache = init(linprob, alg)

    first_u = solve!(lincache).u
    lincache.b = b2
    second_u = solve!(lincache).u

    return norm(first_u + second_u)
end

# Fresh random data and zero shadows for the three-argument objective.
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
b2 = rand(n);
db2 = zeros(n);

# Reverse-mode differentiation through `init` and the two `solve!` calls inside `f`.
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

# Finite-difference reference gradients, one argument at a time.
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1,b2), copy(A))
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x,b2), copy(b1))
db22 = FiniteDiff.finite_difference_gradient(x->f(A,b1,x), copy(b2))

# The `A` comparison is given an explicit absolute tolerance.
@test dA ≈ dA2 atol=5e-5
@test db1 ≈ db12
@test db2 ≈ db22
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ const HAS_EXTENSIONS = isdefined(Base, :get_extension)

if GROUP == "All" || GROUP == "Core"
@time @safetestset "Basic Tests" include("basictests.jl")
@time @safetestset "Re-solve" include("resolve.jl")
VERSION >= v"1.9" && @time @safetestset "Re-solve" include("resolve.jl")
@time @safetestset "Zero Initialization Tests" include("zeroinittests.jl")
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
@time @safetestset "Traits" include("traits.jl")
end

Expand Down