-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improvements to cholesky rrules #630
Conversation
else # C.uplo === 'L' | ||
L = C.L | ||
L̄ = eltype(L) <: Real ? real(tril(Δfactors)) : tril(Δfactors) | ||
mul!(Ā, L', L̄) | ||
LinearAlgebra.copytri!(Ā, 'L', true) | ||
eltype(Ā) <: Complex && _realifydiag!(Ā) | ||
rdiv!(Ā, L) | ||
ldiv!(L', Ā) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since cholesky
doesn't have an uplo
argument, we have no way of making Cholesky
have an uplo
of L
, so this branch is unreachable and untestable. Maybe it's better to put a warning here to tell a user to open an issue, because they've found a magical way to reach this branch.
This could potentially be breaking. e.g. this would break this code in DistributionsAD : |
Ie |
Sorry, language was unclear and has been fixed. Previously only real positive definite (not PDMats) matrices were supported. Now we support also complex positive definite matrices. We still have the constraints of strided and diagonal matrices or Symmetric/Hermitian wrappers of them. No types defined in PDMats should hit these rules. |
@Red-Portal can you check that with these rules your issue would be resolved? |
Hi, just checked, and everything seems good except for the use of |
The adjoint is necessary for the rule to work for complex arrays. Ideally we would have a solution that works for both. We could do this: copy!(LowerTriangular(Ā), UpperTriangular(Ā)') but this is quite a bit slower than GPUArrays has its own |
I have opened a PR on GPUArrays addressing the issue. |
Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā) | ||
function cholesky_HermOrSym_pullback(ΔC) | ||
Ā = _cholesky_pullback_shared_code(C, unthunk(ΔC)) | ||
rmul!(Ā, one(eltype(Ā)) / 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reason for not using
rmul!(Ā, one(eltype(Ā)) / 2) | |
rdiv!(Ā, 2) |
or
rmul!(Ā, one(eltype(Ā)) / 2) | |
ldiv!(2, Ā) |
? That seems more direct and simpler.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rdiv!
performs elementwise divisions, so rmul!
by the reciprocal performs
julia> using BenchmarkTools
julia> foo(x, a) = rmul!(copy(x), inv(a));
julia> bar(x, a) = rdiv!(copy(x), a);
julia> x = randn(100, 100);
julia> @btime foo($x, 2.0);
4.718 μs (2 allocations: 78.17 KiB)
julia> @btime bar($x, 2.0);
9.003 μs (2 allocations: 78.17 KiB)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But isn't that something that should be optimized in base? We want to divide by 2, so the natural thing to do would be to use a division operator instead of manually working with eltype
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW interestingly the difference seems to be smaller on my computer:
julia> @btime foo($x, 2.0);
4.690 μs (2 allocations: 78.17 KiB)
julia> @btime bar($x, 2.0);
7.098 μs (2 allocations: 78.17 KiB)
julia> VERSION
v"1.7.3"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW, rmul!
used this way (to perform a division) is quite common throughout this codebase. e.g. all throughout https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/LinearAlgebra/dense.jl and https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/LinearAlgebra/lapack.jl. Base Julia itself uses this strategy: https://github.com/JuliaLang/julia/blob/b4eb88a71f8c2d8343b21d8fdd1ec403073a222c/stdlib/LinearAlgebra/src/dense.jl#L1595
So I don't think it's unreasonable to use it here for the extra performance. That being said, this is not the computational bottleneck.
Co-authored-by: David Widmann <[email protected]>
GPUArrays.jl #413 has been merged and I just checked that this PR works fine on the GPU as is. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, a few minor things.
Merge when happy
I wonder if we should be overloading getproperty(::Tangent{<:Cholesky}, sym)
to so that dX.U
still works etc.
Given Zygote loses primal types, and that is our biggest user, it seems nonpressing.
We might want an issues where we think about this.
Co-authored-by: Frames Catherine White <[email protected]>
Thanks for the review @oxinabox. What are your thoughts on whether this is breaking?
|
IMO, as large parts of DistributionsAD, this is a workaround, so I'd be happy if it would be removed. AFAICT it is mainly/only used to define rules for |
So @devmotion would you then say that this PR should be considered non-breaking because DistributionsAD is doing something risky anyways? |
Yes, I think DistributionsAD shouldn't hold back this PR. To me it seems the only fix required there will be replacing Edit: I opened a PR: TuringLang/DistributionsAD.jl#226 |
NOTE: This release also breaks Nabla.jl. https://github.com/invenia/Nabla.jl/runs/6959631070?check_suite_focus=true#step:6:390 Since I'm not familiar enough with either package to clearly identify what broke... I'm gonna temporarily restore the original branch so I can bisect the commits. |
Currently relying on piracy, but I've narrowed the specific changes to two specific changes that would probably be easy to re-add to make it non-breaking? invenia/Nabla.jl#217 |
This PR makes a number of improvements to the Cholesky-related
rrule
s:getproperty(::Cholesky, ::Symbol)
now returns aTangent
with afactors
entry instead of:U
or:L
(fixes part of Wrong call oncholesky
rrule
#611)rdiv!
instead ofBLAS.trsm!
(fixes Need a GPU compatiblerrule
for Cholesky #629)cholesky(::Quaternion)
andcholesky(::Diagonal{<:Quaternion})
, though this is untestedcholesky(::Number)
andcholesky(::Diagonal)
for failed factorizationThunk
cotangents (just unthunk received cotangents, which is a no-op forTangent
)