Add Enzyme extension #377
Sample call:

```julia
using Enzyme
using LinearSolve, LinearAlgebra

n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
b2 = rand(n);
db2 = zeros(n);

function f(A, b1, b2; alg = LUFactorization())
    prob = LinearProblem(A, b1)
    sol1 = solve(prob, alg)
    s1 = sol1.u
    norm(s1)
end

f(A, b1, b2) # Uses BLAS
Enzyme.autodiff(Reverse, f, Duplicated(A, dA), Duplicated(b1, db1), Duplicated(b2, db2))
@show dA, db1, db2
```
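For reference, the gradients the sample call should produce can be worked out by hand. The sketch below uses only the `LinearAlgebra` standard library, not Enzyme or LinearSolve; it assumes the standard adjoint of a linear solve (for `x = A \ b1` and `g = x / norm(x)`: `db1 = A' \ g` and `dA = -db1 * x'`). `reference_grads` and `fd` are hypothetical helpers introduced only for this check. Since `f` never reads `b2`, `db2` is expected to stay zero.

```julia
using LinearAlgebra, Random

# Hypothetical reference gradients for f(A, b1, b2) = norm(A \ b1), derived by hand.
# With x = A \ b1 and g = ∂norm(x)/∂x = x / norm(x):
#   db1 = A' \ g      (adjoint solve with the transpose)
#   dA  = -db1 * x'   (outer product)
#   db2 = 0           (b2 is never used by f)
function reference_grads(A, b1, b2)
    x = A \ b1
    g = x / norm(x)
    db1 = A' \ g
    dA = -db1 * x'
    db2 = zero(b2)
    return dA, db1, db2
end

# Hypothetical helper: central finite difference of scalar h at p along direction v.
fd(h, p, v; ε = 1e-6) = (h(p + ε * v) - h(p - ε * v)) / (2ε)

Random.seed!(0)
n = 4
A = rand(n, n) + n * I   # shift the diagonal so A is well conditioned
b1 = rand(n)
b2 = rand(n)
f(A, b1, b2) = norm(A \ b1)

dA, db1, db2 = reference_grads(A, b1, b2)

# Directional-derivative checks: both differences should be close to zero.
V = rand(n, n)
v = rand(n)
@show dot(dA, V) - fd(M -> f(M, b1, b2), A, V)
@show dot(db1, v) - fd(c -> f(A, c, b2), b1, v)
```

Comparing Enzyme's `dA` and `db1` against these values, up to floating-point tolerance, is one way to sanity-check the rule.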
Codecov Report

```diff
@@              Coverage Diff              @@
##               main     #377       +/-   ##
===========================================
+ Coverage    20.01%   68.25%   +48.24%
===========================================
  Files           14       24       +10
  Lines         1444     1884      +440
===========================================
+ Hits           289     1286      +997
+ Misses        1155      598      -557
```

... and 22 files with indirect coverage changes
It looks like this only handles the case of

```julia
using LinearSolve, LinearAlgebra
# using MKL_jll

n = 100
A = rand(n, n)
b1 = rand(n);
b2 = rand(n);

function f(A, b1, b2; alg = LUFactorization())
    prob = LinearProblem(A, b1)

    linsolve = init(prob, alg)
    sol1 = solve!(linsolve)
    s1 = copy(sol1.u)

    linsolve.b = b2
    sol2 = solve!(linsolve)
    s2 = copy(sol2.u)
    norm(s1 + s2)
end

f(A, b1, b2) # Uses BLAS
f(A, b1, b2; alg=RFLUFactorization()) # Uses loops
f(A, b1, b2; alg=MKLLUFactorization()) # Requires `using MKL_jll`

using Enzyme
dA = zero(A)
db1 = zero(b1)
db2 = zero(b2)
Enzyme.autodiff(Reverse, f, Duplicated(A, dA),
                Duplicated(b1, db1), Duplicated(b2, db2))
```

which is EnzymeAD/Enzyme.jl#1065. I at least added a test for the
Pushed the extension for `solve!` and `init` now. While I was at it, I also added batch mode support.
ext/LinearSolveEnzymeExt.jl

```julia
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    res = func.val(prob.val, alg.val; kwargs...)
    dres = if EnzymeRules.width(config) == 1
        func.val(prob.dval, alg.val; kwargs...)
    else
        (func.val(dval, alg.val; kwargs...) for dval in prob.dval)
    end
    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    return (nothing, nothing)
end
```
Why is this one required? It seems like it doesn't do much?
Init hits that global variable stuff, so we need a rule for corresponding shadow initialization.
I see
What in here was required for batch mode support?
ext/LinearSolveEnzymeExt.jl

```julia
    (dr.u for dr in dres)
end

cache = (copy(linsolve.val.A), res, resvals)
```
Is this copy necessary?
Not specializing to just `Duplicated` but also supporting `BatchDuplicated`, whose `dval` is a tuple of shadows.

As a tuple, does that have an issue scaling to, say, a batch of 100 or 1000 things?
For conservative correctness, yes. `A` may be modified between the forward and reverse pass.

The overwritten set of Bools says whether the outermost struct pointer is overwritten; it has no information about internal members being overwritten. As alias analysis in Julia and elsewhere improves (or we get an ImmutableArray type or something similar), this copy can be elided in the future.
On Fri, Sep 22, 2023 at 5:28 PM Christopher Rackauckas commented on this pull request, in ext/LinearSolveEnzymeExt.jl:

```julia
if EnzymeRules.width(config) == 1
    dres.u .= 0
else
    for dr in dres
        dr.u .= 0
    end
end

resvals = if EnzymeRules.width(config) == 1
    dres.u
else
    (dr.u for dr in dres)
end

cache = (copy(linsolve.val.A), res, resvals)
```

> Is this copy necessary?
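One concrete way `A` can change between the passes is an in-place factorization. The sketch below uses `LinearAlgebra.lu!` purely as an illustration; whether a particular LinearSolve algorithm factorizes `linsolve.val.A` in place is an assumption here, not something established in this thread.

```julia
using LinearAlgebra

# Illustration only: an in-place factorization overwrites its input, so a
# reverse pass that reads `A` afterwards sees the packed L/U factors,
# not the original matrix.
A = [4.0 3.0; 6.0 3.0]
A_saved = copy(A)   # what a forward pass would stash in its cache
F = lu!(A)          # factorizes in place; A now holds the L and U factors

@show A == A_saved              # false: A was overwritten
# The saved copy is still consistent with the factors (row-permuted by F.p):
@show F.L * F.U ≈ A_saved[F.p, :]
```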
ext/LinearSolveEnzymeExt.jl

```julia
end

for (dA, db, dy) in zip(dAs, dbs, dys)
    invprob = LinearSolve.LinearProblem(transpose(A), dy)
```
In the forward pass the matrix `A` is factorized, so in theory we don't need to factorize it again; we could just transpose the factorization of `A` from the forward pass. Is there a way to grab that?
It supports being used with arbitrary sizes. In practice, of course, some sizes could be better than others, e.g., a power of two for vectorization's sake. Likewise, if a computation can be reused for all batch elements, that could improve performance. E.g., if `transpose(A)` generated a new matrix and not a view, we could do that once for all batches.
On Fri, Sep 22, 2023 at 5:29 PM Christopher Rackauckas wrote:

> As a tuple, does that have an issue scaling to say batch of a 100 or 1000 things?
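As an illustration of reusing work across batch elements, here is a hedged sketch in plain Julia. `batched_adjoint_solves` is a hypothetical helper, and `dys` stands in for the tuple of shadows a `BatchDuplicated` argument carries. The factorization is computed once, and each batch lane then only pays for a transposed solve instead of a full refactorization.

```julia
using LinearAlgebra

# Hypothetical batched reverse step: factorize A once and reuse the
# factorization for every shadow in the batch (dys is a tuple, mirroring
# the tuple of shadows in a BatchDuplicated dval).
function batched_adjoint_solves(A, dys::Tuple)
    F = lu(A)                                  # one O(n^3) factorization
    return map(dy -> transpose(F) \ dy, dys)   # O(n^2) triangular solves per lane
end

A = [2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0]
dys = (rand(3), rand(3), rand(3))
dbs = batched_adjoint_solves(A, dys)
@show all(db ≈ transpose(A) \ dy for (db, dy) in zip(dbs, dys))
```

For dense `A` this keeps the cost of each additional lane at O(n²) rather than O(n³), which is the kind of saving that grows with batch size.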
Sure, if you know a better set of things to cache, we can choose those instead. I don't know much about the internals of `solve`, so I went for this form.
On Fri, Sep 22, 2023 at 5:31 PM Christopher Rackauckas commented on this pull request, in ext/LinearSolveEnzymeExt.jl:

```julia
end

dAs = if EnzymeRules.width(config) == 1
    (linsolve.dval.A,)
else
    (dval.A for dval in linsolve.dval)
end

dbs = if EnzymeRules.width(config) == 1
    (linsolve.dval.b,)
else
    (dval.b for dval in linsolve.dval)
end

for (dA, db, dy) in zip(dAs, dbs, dys)
    invprob = LinearSolve.LinearProblem(transpose(A), dy)
```

> In the forward pass the matrix A is factorized, so in theory we don't need to factorize it again, just transpose A from the forward pass. Is there a way to grab that?
The key that I'm pointing out here is similar to the top of https://docs.sciml.ai/LinearSolve/stable/tutorials/caching_interface/. But here, what

```julia
_A = lu!(A)
_A \ b1
```

and then the backpass is:

```julia
_At = lu!(A')
_At \ db1
```

but we also have that (essentially) So what I'm wondering is if it's safe to assume that
It's the same Julia object, but it's possible its fields may have been modified. If it's immutable, then it's the same.

Even if it's overwritten, however, you can still add whatever is relevant from the LU into the cache and use that as a starting point.

Awesome, I'll leave that as a follow-up, no need to handle it in this PR. But the tests do need to get fixed.
The transpose of the factorization is a factorization of the transpose, so it can be used to solve with `A'` directly:

```julia
using LinearAlgebra

A = rand(4,4)
luA = lu(A)
At = transpose(A)
luAt = lu(At)

b = rand(4)
x = A \ b
x2 = A' \ b
x3 = luA \ b
x4 = luAt \ b
x5 = luA' \ b

x ≈ x3
x2 ≈ x4 ≈ x5
```

Confirmed from https://web.mit.edu/18.06/www/Spring17/Transposes.pdf. We can use this to generalize and optimize a bit.
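Combining that identity with the suggestion to cache whatever is relevant from the LU, a reverse pass could store the forward factorization and reuse it for the adjoint solve. This is a minimal sketch under stated assumptions (dense real `A`, with LinearAlgebra's `lu` standing in for whatever factorization the chosen algorithm produces); `forward_solve` and `reverse_solve` are hypothetical names, not the extension's API.

```julia
using LinearAlgebra

# Hypothetical forward pass for x = A \ b: cache the LU factorization
# (instead of a copy of A) together with the solution.
function forward_solve(A, b)
    F = lu(A)            # factorize once in the forward pass
    x = F \ b
    return x, (F, x)     # primal result plus what the reverse pass needs
end

# Hypothetical reverse pass: given the cotangent dx of x, return (dA, db)
#   db = A' \ dx,  dA = -db * x'
function reverse_solve(cache, dx)
    F, x = cache
    db = transpose(F) \ dx     # reuses the forward factorization (real eltype assumed)
    dA = -db * transpose(x)    # outer product; A itself is never read here
    return dA, db
end

A = [4.0 1.0 0.0; 1.0 5.0 2.0; 0.0 2.0 6.0]
b = [1.0, 2.0, 3.0]
x, cache = forward_solve(A, b)
dx = [1.0, -1.0, 0.5]
dA, db = reverse_solve(cache, dx)
@show db ≈ transpose(A) \ dx
@show dA ≈ -(transpose(A) \ dx) * transpose(x)
```

Because the reverse pass only touches the cached `F` and `x`, this design also sidesteps the question of whether `A` was overwritten in between.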
ext/LinearSolveEnzymeExt.jl

```julia
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
    y, dys = cache
    _linsolve = linsolve.val
```
This is still wrong, because `linsolve` could still have been overwritten between the forward and reverse pass. You need to cache it.
Okay, I was just about to ask that, thanks. I think with that this may be complete. Though check the batch syntax in the test: it still errors with `BatchDuplicated` and I'm not sure what to do there.
what is the error log from?
ERROR: TypeError: in ccall argument 6, expected Tuple{Float64, Float64}, got a value of type Float64
Stacktrace:
[1] macro expansion
@ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9774 [inlined]
[2] enzyme_call
@ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9452 [inlined]
[3] CombinedAdjointThunk
@ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9415 [inlined]
[4] autodiff
@ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:213 [inlined]
[5] autodiff
@ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:236 [inlined]
[6] autodiff(::ReverseMode{false, FFIABI}, ::typeof(f), ::BatchDuplicated{Matrix{Float64}, 2}, ::BatchDuplicated{Vector{Float64}, 2})
@ Enzyme C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:222
[7] top-level scope
@ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:36
Oh, that's an easy one (which we should fix). You can't use an active return in batch mode right now (which also makes little sense here, since you'd backpropagate the same value to each). Just wrap that function in a closure that stores the result to a vector or something.
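The suggested workaround can be sketched as a wrapper that writes the scalar into an output array and returns `nothing`, so no active return is needed. `fbatch!` is a hypothetical name. The Enzyme call is left as a comment because it was not run for this sketch; it assumes `BatchDuplicated(val, (shadow1, shadow2))`, a `Const` return annotation, and output shadows seeded with `1.0`.

```julia
using LinearAlgebra

# Scalar-returning function like the one in the test (an active return).
f(A, b) = norm(A \ b)

# Hypothetical wrapper: store the scalar in `out` and return `nothing`, so the
# output becomes a (batch-)duplicated array instead of an active return value.
function fbatch!(out, A, b)
    out[1] = f(A, b)
    return nothing
end

A = [3.0 1.0; 1.0 2.0]
b = [1.0, 1.0]
out = zeros(1)
fbatch!(out, A, b)
@show out[1] ≈ f(A, b)

# Sketch of the intended call (assumption, not run here):
#   dA1, dA2 = zero(A), zero(A); db1, db2 = zero(b), zero(b)
#   Enzyme.autodiff(Reverse, fbatch!, Const,
#                   BatchDuplicated(out, ([1.0], [1.0])),
#                   BatchDuplicated(A, (dA1, dA2)),
#                   BatchDuplicated(b, (db1, db2)))
```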
Makes sense, yeah the test was a bit dumb but just a quick sanity check 😓. Fixing that gives:
ERROR: Enzyme execution failed.
Enzyme: Augmented forward pass custom rule Tuple{EnzymeCore.EnzymeRules.ConfigWidth{2, true, true, (false, false, false)}, Const{typeof(init)}, Type{BatchDuplicated{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, 2}}, BatchDuplicated{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, 2}, Const{LUFactorization{RowMaximum}}} return type mismatch, expected EnzymeCore.EnzymeRules.AugmentedReturn{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, Tuple{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}}, Any} found EnzymeCore.EnzymeRules.AugmentedReturn{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, 
LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, Tuple{Base.Generator{Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, LinearSolveEnzymeExt.var"#3#6"}, Base.Generator{Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, LinearSolveEnzymeExt.var"#4#7"}}}
Stacktrace:
[1] #solve#5
@ C:\Users\accou\.julia\dev\LinearSolve\src\common.jl:193
[2] solve
@ C:\Users\accou\.julia\dev\LinearSolve\src\common.jl:190
[3] #fbatch#207
@ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:39
[4] fbatch
@ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:36
[5] fbatch
@ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:0
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:3066
The solving twice tests are a bit odd:

```julia
julia> db1
4-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0

julia> db2
4-element Vector{Float64}:
  2.1215949279204196
 -3.7095838683317943
 -1.2286715744423384
  5.967859589815037
```

It doubles.
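For the solve-twice test the expected answer is easy to state: `A \ b1 + A \ b2 == A \ (b1 + b2)`, so the gradients with respect to `b1` and `b2` should be equal and nonzero. The hand-derived sketch below (plain `LinearAlgebra`, hypothetical helper `solve_twice_grads`) computes them. An all-zero `db1` together with a `db2` equal to twice the correct value would be consistent with both solves' cotangents being attributed to `b2`, for example if the reverse rule reads `linsolve.b` after it was reassigned. That is a hypothesis in line with the earlier review comment about caching `linsolve`, not a confirmed diagnosis.

```julia
using LinearAlgebra

# Hypothetical reference gradients for f(A, b1, b2) = norm(A \ b1 + A \ b2).
# Since A \ b1 + A \ b2 == A \ (b1 + b2), the gradients w.r.t. b1 and b2 agree.
function solve_twice_grads(A, b1, b2)
    F = lu(A)
    s = F \ b1 + F \ b2        # s1 + s2
    g = s / norm(s)            # ∂norm(s)/∂s
    db = transpose(F) \ g      # same adjoint solve for both right-hand sides
    dA = -db * transpose(s)
    return dA, db, copy(db)    # (dA, db1, db2)
end

A = [4.0 1.0 0.0 0.0; 1.0 4.0 1.0 0.0; 0.0 1.0 4.0 1.0; 0.0 0.0 1.0 4.0]
b1 = [1.0, 0.0, 2.0, 1.0]
b2 = [0.5, 1.0, 0.0, 1.0]
dA, db1, db2 = solve_twice_grads(A, b1, b2)
@show db1 == db2     # both should be nonzero and equal
# Hypothetical: what db2 would show if both contributions landed on b2.
@show 2 .* db1
```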
We can skip over that last test to merge, but do you know why that one algorithm would be treated so differently by Enzyme? I would've thought it didn't care whether we're capturing stuff in rules, but it treats this algorithm quite differently: https://github.com/SciML/LinearSolve.jl/actions/runs/6290016689/job/17077077461?pr=377#step:6:807
Requires current Enzyme `main` for a custom-rules fix.