Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to cholesky rrules #630

Merged
merged 26 commits into from
Jun 17, 2022
Merged

Improvements to cholesky rrules #630

merged 26 commits into from
Jun 17, 2022

Conversation

sethaxen
Copy link
Member

@sethaxen sethaxen commented Jun 11, 2022

This PR makes a number of improvements to the Cholesky-related rrules:

  • The rrule for getproperty(::Cholesky, ::Symbol) now returns a Tangent with a factors entry instead of :U or :L (fixes part of Wrong call on cholesky rrule #611)
  • Use rdiv! instead of BLAS.trsm! (fixes Need a GPU compatible rrule for Cholesky #629)
  • Support complex numbers and complex PD matrices (also cholesky(::Quaternion) and cholesky(::Diagonal{<:Quaternion}), though this is untested
  • Fix cholesky(::Number) and cholesky(::Diagonal) for failed factorization
  • Remove rules for 1-arg methods (the 1-arg methods fall back to the 2-arg methods, so they should be hit anyways)
  • Remove specialization for Thunk cotangents (just unthunk received cotangents, which is a no-op for Tangent)
  • Add missing tests

@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jun 11, 2022
Comment on lines 508 to 516
else # C.uplo === 'L'
L = C.L
L̄ = eltype(L) <: Real ? real(tril(Δfactors)) : tril(Δfactors)
mul!(Ā, L', L̄)
LinearAlgebra.copytri!(Ā, 'L', true)
eltype(Ā) <: Complex && _realifydiag!(Ā)
rdiv!(Ā, L)
ldiv!(L', Ā)
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since cholesky doesn't have an uplo argument, we have no way of making Cholesky have an uplo of L, so this branch is unreachable and untestable. Maybe it's better to put a warning here to tell a user to open an issue, because they've found a magical way to reach this branch.

@sethaxen
Copy link
Member Author

* [x]  The rrule for `getproperty(::Cholesky, ::Symbol)` now returns a `Tangent` with a `factors` entry instead of `:U` or `:L` (fixes part of [Wrong call on `cholesky` `rrule` #611](https://github.com/JuliaDiff/ChainRules.jl/issues/611))

This could potentially be breaking. e.g. this would break this code in DistributionsAD :
https://github.com/TuringLang/DistributionsAD.jl/blob/44a57e974e386ab576a0251967cf7e57e42c63f7/src/common.jl#L3-L29.

@sethaxen sethaxen marked this pull request as ready for review June 11, 2022 16:17
@devmotion
Copy link
Member

 Support [...] PD matrices

Ie cholesky(::AbstractPDMat) will hit an rrule here now? We don't want these to be handled by a generic method but instead AD should just follow and differentiate the optimized implementations in PDMats. We could add opt-outs (there's an open issue for det as well) but these seem still a bit unsatisfying to me - if the upstream rules would be less generic, everything would just work without having to think about AD in PDMats (and having to know about these definitions in CR).

@sethaxen
Copy link
Member Author

Support [...] PD matrices

Ie cholesky(::AbstractPDMat) will hit an rrule here now?

Sorry, language was unclear and has been fixed. Previously only real positive definite (not PDMats) matrices were supported. Now we support also complex positive definite matrices. We still have the constraints of strided and diagonal matrices or Symmetric/Hermitian wrappers of them. No types defined in PDMats should hit these rules.

@github-actions github-actions bot removed the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jun 11, 2022
@sethaxen sethaxen requested a review from devmotion June 11, 2022 20:24
@sethaxen sethaxen requested a review from mzgubic June 11, 2022 20:25
@sethaxen
Copy link
Member Author

@Red-Portal can you check that with these rules your issue would be resolved?

@Red-Portal
Copy link

Red-Portal commented Jun 12, 2022

Hi, just checked, and everything seems good except for the use of copytri! with the conjugate option set to true. This triggers scalar indexing for CuArrays. Zygote's adjoint seems to be good without conjugation. Could we get around it?

@sethaxen
Copy link
Member Author

Zygote's adjoint seems to be good without conjugation. Could we get around it?

The adjoint is necessary for the rule to work for complex arrays. Ideally we would have a solution that works for both. We could do this:

copy!(LowerTriangular(Ā), UpperTriangular(Ā)')

but this is quite a bit slower than copytri! for large arrays.

GPUArrays has its own copytri! implementation that is missing two of the options in LinearAlgebra.copytri!, which is why as soon as we use one of those options, we fall back to the one in Base. https://github.com/JuliaGPU/GPUArrays.jl/blob/fc0d327ecc2fd0b3b73427cf6f491591aa096b75/src/host/linalg.jl#L35-L59 This seems like something that should be fixed in GPUArrays.

@Red-Portal
Copy link

Red-Portal commented Jun 12, 2022

I have opened a PR on GPUArrays addressing the issue.

src/rulesets/LinearAlgebra/factorization.jl Outdated Show resolved Hide resolved
Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā)
function cholesky_HermOrSym_pullback(ΔC)
= _cholesky_pullback_shared_code(C, unthunk(ΔC))
rmul!(Ā, one(eltype(Ā)) / 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for not using

Suggested change
rmul!(Ā, one(eltype(Ā)) / 2)
rdiv!(Ā, 2)

or

Suggested change
rmul!(Ā, one(eltype(Ā)) / 2)
ldiv!(2, Ā)

? That seems more direct and simpler.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rdiv! performs elementwise divisions, so $n^2$ division operations, whereas rmul! by the reciprocal performs $n^2$ elementwise multiplications and a single division. Division is generally more expensive than multiplication, so this is cheaper e.g.

julia> using BenchmarkTools

julia> foo(x, a) = rmul!(copy(x), inv(a));

julia> bar(x, a) = rdiv!(copy(x), a);

julia> x = randn(100, 100);

julia> @btime foo($x, 2.0);
  4.718 μs (2 allocations: 78.17 KiB)

julia> @btime bar($x, 2.0);
  9.003 μs (2 allocations: 78.17 KiB)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't that something that should be optimized in base? We want to divide by 2, so the natural thing to do would be to use a division operator instead of manually working with eltype.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW interestingly the difference seems to be smaller on my computer:

julia> @btime foo($x, 2.0);
  4.690 μs (2 allocations: 78.17 KiB)

julia> @btime bar($x, 2.0);
  7.098 μs (2 allocations: 78.17 KiB)

julia> VERSION
v"1.7.3"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, rmul! used this way (to perform a division) is quite common throughout this codebase. e.g. all throughout https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/LinearAlgebra/dense.jl and https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/LinearAlgebra/lapack.jl. Base Julia itself uses this strategy: https://github.com/JuliaLang/julia/blob/b4eb88a71f8c2d8343b21d8fdd1ec403073a222c/stdlib/LinearAlgebra/src/dense.jl#L1595

So I don't think it's unreasonable to use it here for the extra performance. That being said, this is not the computational bottleneck.

test/rulesets/LinearAlgebra/factorization.jl Show resolved Hide resolved
@Red-Portal
Copy link

Red-Portal commented Jun 14, 2022

GPUArrays.jl #413 has been merged and I just checked that this PR works fine on the GPU as is.

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, a few minor things.
Merge when happy

I wonder if we should be overloading getproperty(::Tangent{<:Cholesky}, sym) to so that dX.U still works etc.
Given Zygote loses primal types, and that is our biggest user, it seems nonpressing.
We might want an issues where we think about this.

Co-authored-by: Frames Catherine White <[email protected]>
@sethaxen
Copy link
Member Author

Thanks for the review @oxinabox. What are your thoughts on whether this is breaking?

* [x]  The rrule for `getproperty(::Cholesky, ::Symbol)` now returns a `Tangent` with a `factors` entry instead of `:U` or `:L` (fixes part of [Wrong call on `cholesky` `rrule` #611](https://github.com/JuliaDiff/ChainRules.jl/issues/611))

This could potentially be breaking. e.g. this would break this code in DistributionsAD : https://github.com/TuringLang/DistributionsAD.jl/blob/44a57e974e386ab576a0251967cf7e57e42c63f7/src/common.jl#L3-L29.

@devmotion
Copy link
Member

IMO, as large parts of DistributionsAD, this is a workaround, so I'd be happy if it would be removed. AFAICT it is mainly/only used to define rules for cholesky for Tracker and ReverseDiff. It's a bit horrifying that loading DistributionsAD changes cholesky for these packages, so the sooner it's gone the better I would say.

@sethaxen
Copy link
Member Author

So @devmotion would you then say that this PR should be considered non-breaking because DistributionsAD is doing something risky anyways?

@devmotion
Copy link
Member

devmotion commented Jun 17, 2022

Yes, I think DistributionsAD shouldn't hold back this PR. To me it seems the only fix required there will be replacing U in https://github.com/TuringLang/DistributionsAD.jl/blob/48c43f8e8062ba95542330735593b5275117e592/src/common.jl#L10 and https://github.com/TuringLang/DistributionsAD.jl/blob/48c43f8e8062ba95542330735593b5275117e592/src/common.jl#L24 with factors. That is the right thing anyway (for the time being until this stuff is removed completely) since the primal function returns the factors as first element of a tuple. I guess it is just not done currently since otherwise the CR would have returned ZeroTangent (due to the bug fixed in the PR here).

Edit: I opened a PR: TuringLang/DistributionsAD.jl#226

@sethaxen sethaxen merged commit 6ff4c31 into main Jun 17, 2022
@sethaxen sethaxen deleted the choleskyfactors branch June 17, 2022 11:45
@rofinn
Copy link
Contributor

rofinn commented Jun 20, 2022

NOTE: This release also breaks Nabla.jl.

https://github.com/invenia/Nabla.jl/runs/6959631070?check_suite_focus=true#step:6:390

Since I'm not familiar enough with either package to clearly identify what broke... I'm gonna temporarily restore the original branch so I can bisect the commits.

@rofinn rofinn restored the choleskyfactors branch June 20, 2022 22:01
@rofinn
Copy link
Contributor

rofinn commented Jun 22, 2022

Currently relying on piracy, but I've narrowed the specific changes to two specific changes that would probably be easy to re-add to make it non-breaking? invenia/Nabla.jl#217

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Need a GPU compatible rrule for Cholesky
5 participants