-
-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add forward enzyme rules for init and solve #416
Conversation
Codecov Report
@@ Coverage Diff @@
## main #416 +/- ##
==========================================
- Coverage 64.94% 63.83% -1.11%
==========================================
Files 26 26
Lines 2068 2099 +31
==========================================
- Hits 1343 1340 -3
- Misses 725 759 +34
... and 1 file with indirect coverage changes 📣 Codecov offers a browser extension for seamless coverage viewing on GitHub. Try it in Chrome or Firefox today! |
76d30f8
to
e89be05
Compare
test/enzyme.jl
Outdated
@test_broken en_jac ≈ fd_jac |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These tests fail due to numerical imprecision. Not entirely sure why.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
df/db:
manual_jac = [-5.061916901336408, 1.5770609499033128, 5.446411839233853, 0.612648432464526]
fd_jac = [-5.061916947364807, 1.5770610570907593, 5.446411848068237, 0.6126484870910645]
en_jac = [-5.061916901336408, 1.5770609499033128, 5.446411839233853, 0.612648432464526]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's an expected difference due to summation order.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed in 31745cc
4bcdc63
to
3f88629
Compare
ext/LinearSolveEnzymeExt.jl
Outdated
end | ||
|
||
dres = deepcopy(res) | ||
invA = inv(A) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
never use inv
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed in d010911
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not fully, see #416 (comment). Your version has an extra factorization and is using the wrong linear solver.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed in 31745cc
ext/LinearSolveEnzymeExt.jl
Outdated
invA = inv(A) | ||
db = linsolve.dval.b | ||
dA = linsolve.dval.A | ||
dres.u .= invA * (db - dA * res.u) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no need to refactorize here. This is just A * u = db - dA * u
, or u = A \ (db - dA * u)
. But this is just a linsolve
call. Not only that, but it's the same operator as the one used in the forward pass, so you don't need to refactorize A
. Therefore, this should simply use the same linsolve
and do linsolve.b = db - dA * u
and then solve!
.
But I don't think the formula is correct. Isn't it just linsolve.b = db
and then solve!(linsolve)
then du = linsolve.u
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But linsolves.A
seem to mutate after the solve. That is why I had to add all this other stuff.
MWE:
using LinearSolve
A = rand(5,5)
b = rand(5)
prob = LinearProblem(A,b)
linsolve = init(prob)
@show linsolve.A
sol = solve!(linsolve)
@show linsolve.A
5×5 Matrix{Float64}:
0.207897 0.0737386 0.220551 0.935437 0.883482
0.362979 0.687049 0.220086 0.216771 0.145246
0.538937 0.577133 0.965879 0.438704 0.689019
0.348145 0.174776 0.0491639 0.145817 0.155529
0.92258 0.637729 0.508441 0.166917 0.218296
5×5 Matrix{Float64}:
0.92258 0.637729 0.508441 0.166917 0.218296
0.393439 0.436142 0.0200455 0.1511 0.0593595
0.584162 0.469106 0.659464 0.270316 0.533653
0.225343 -0.160427 0.16558 0.877305 0.755451
0.37736 -0.151045 -0.211799 0.185687 0.0548682
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the mutation that it's doing is lu!(A)
, i.e. its using the same memory in the representation of the LU-factorization. You want to let it use exactly the same mutated A
via cache.cacheval
in order to do the next solve, which is precisely what's done in the caching interface. See https://docs.sciml.ai/LinearSolve/stable/tutorials/caching_interface/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. And the simple derivation of the formula is mentioned in https://scicomp.stackexchange.com/a/29421.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, the correctness of the formula is checked with FiniteDiff
in the tests - https://github.com/SciML/LinearSolve.jl/pull/416/files#diff-7c5a302dc0e55153407f2354959d201e28d2b9a61c4ef6bc86f05697a52b4cacR180-R196
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like a real test failure. |
if RT <: DuplicatedNoNeed | ||
return dres | ||
elseif RT <: Duplicated | ||
return Duplicated(res, dres) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't handle batching atm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Throw a good error for now? It would be good to get something merged even if it doesn't handle every case, as long as the errors are clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tried to address this in 585cbb8. We can make a followup PR for batch support.
@ChrisRackauckas The test failures seem unrelated. |
… due to summation ordering
…d random failures
2803bda
to
4265053
Compare
TODO: