Use AbstractVector in LKJ and LKJCholesky bijectors #253
Conversation
[WIP] LKJ and LKJCholesky bijectors
Actually let's wait, as the AD tests I've just added on the roundtrip transformation fail, will have a look.
Looked more into it. The test is failing on… Unsure about the… All these tests are for AD through…
Regarding the test failures, it's a bit strange. This seems related: JuliaDiff/ForwardDiff.jl#606. But I thought this was fixed because they pulled 0.10.33 after the discussion there, and deferred the breaking changes to 0.11. The tests are running on 0.10.35, so I don't get why we're seeing this 😕 I think this needs a bit further inspection. And we do actually need to AD through the…

(Btw, I'm not done with my review, will continue later)
Good points @torfjelde, thanks! Here's more about the AD issues:

**Tracker**

From discussions elsewhere (Slack) I understand that we agree to drop support for this.

**ForwardDiff**
It might actually not be. It seems like a numerical issue in the `ishermitian` check when comparing `Dual` values. I found two samples from the same `LKJ(5, 1)` distribution, one that fails and one that passes:

```julia
using Bijectors, DistributionsAD, LinearAlgebra
using Bijectors: VecCorrBijector
using ForwardDiff
using ForwardDiff: Dual

b = VecCorrBijector('C') # bijector(LKJ(5, 1))
binv = inverse(b)
f = x -> sum(b(binv(x)))

# x_f ~ LKJ(5, 1)
x_f = [
    1.0 0.38808945715615550398 0.55251148082365042491 0.06333711952583508109 -0.51630779311225594164
    0.38808945715615550398 1.0 0.31760367441586356829 0.34585990227668395036 0.06051504059466897290
    0.55251148082365042491 0.31760367441586356829 1.0 0.17416714618194936715 -0.02825518349677474950
    0.06333711952583508109 0.34585990227668395036 0.17416714618194936715 1.0 -0.07513830680477201485
    -0.51630779311225594164 0.06051504059466897290 -0.02825518349677474950 -0.07513830680477201485 1.0
]
df_f = ForwardDiff.gradient(f, b(x_f)) # Errors, ishermitian returns false

# x_s ~ LKJ(5, 1)
x_s = [
    1.0 -0.01569213125090618277 -0.79039374741027101923 -0.03400980954333766848 0.54371128016847525277
    -0.01569213125090618277 1.0 -0.19877390203937703173 -0.37124942960738860354 -0.39209191569764001439
    -0.79039374741027101923 -0.19877390203937703173 1.0 0.03430683023840974677 -0.62744676631878926187
    -0.03400980954333766848 -0.37124942960738860354 0.03430683023840974677 1.0 0.50841756191547016197
    0.54371128016847525277 -0.39209191569764001439 -0.62744676631878926187 0.50841756191547016197 1.0
]
df_s = ForwardDiff.gradient(f, b(x_s)) # Runs, ishermitian returns true

# Let's see where x_f fails
function ish(A::AbstractMatrix)
    # Just a copy of ishermitian with a `@show`
    indsm, indsn = axes(A)
    if indsm != indsn
        return false
    end
    for i = indsn, j = i:last(indsn)
        if A[i, j] != adjoint(A[j, i])
            @show abs(A[i, j] - adjoint(A[j, i]))
            return false
        end
    end
    return true
end

y_f = b(x_f)
ish(binv(Dual.(y_f))) # Returns false, shows abs(A[i, j] - adjoint(A[j, i])) = Dual{Nothing}(2.0816681711721685e-17)

# Without using `Dual`s though, all is good
ish(binv(y_f)) # Returns true
```

So the roundtrip only fails with `Dual`s, and the asymmetry is at the level of floating-point noise (~2e-17).

EDIT: Tried using Zygote as well.

**Zygote**

This indeed has to do with how the `Cholesky` factors are accessed:

```julia
using Bijectors, DistributionsAD, LinearAlgebra
using Zygote

dist = LKJ(5, 1)
x = rand(dist)

g = x -> sum(cholesky(x).U)
dg = Zygote.gradient(g, x) # Returns correct gradient

h = x -> sum(cholesky(x).UL)
dh = Zygote.gradient(h, x) # Returns (nothing, )
```

So one option is to change

> Line 18 in 0d599e8

to `X.U`, take the potential extra allocation (if `uplo === :L`) and always work with `UpperTriangular` downstream. Using `PDMats.chol_upper` as suggested here results in the same issue by accessing `getproperty(::Cholesky, :factors)`.

Any thoughts on how to handle this?
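For what it's worth, one possible workaround for the `Dual` asymmetry could be to symmetrise before factorising, so `cholesky` doesn't hit the exact-equality `ishermitian` check. This is only a sketch of the idea (`symmetrise` is a hypothetical helper, not what the PR does), illustrated with an explicit tiny asymmetry instead of `Dual`s:

```julia
using LinearAlgebra

# Hypothetical workaround sketch: the reconstructed matrix is Hermitian only
# up to ~1e-17 floating-point noise, so `cholesky` (which checks
# `ishermitian` with exact equality) rejects it. Wrapping an explicitly
# symmetrised copy in `Symmetric` sidesteps the check.
symmetrise(X) = Symmetric((X + X') / 2)

X = [1.0 0.5; 0.5 + 1e-15 1.0]  # tiny asymmetry, like the Dual case above
ishermitian(X)                   # false: the check uses exact equality
F = cholesky(symmetrise(X))      # succeeds
```

Whether this is acceptable numerically for the bijector (rather than fixing the access pattern) is a design question, of course.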
It is for the case of…
I restarted the Inference tests multiple times and the…
Almost there! But I think the cholesky-version should just be its own struct so we avoid the type-instabilities.
Otherwise it's looking pretty dank!
And I'll have a look at the ForwardDiff issue.
src/bijectors/corr.jl
# Fields
- mode :`Symbol`. Controls the inverse transformation:
  - if `mode === :C` returns a correlation matrix
Do we need this? I'm personally happy to just support U or L.
That is, make the cholesky version into a separate type, e.g. `VecCholCorrBijector`. This will avoid the type-instabilities + moves the conditional handling you have in some functions into multiple dispatch instead.
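The suggested split could look roughly like this (a minimal sketch; the names with a `Sketch` suffix are illustrative, not the actual Bijectors.jl types):

```julia
# Illustrative only: instead of one struct carrying a runtime `mode::Symbol`
# field and branching on it, define one concrete type per output so the
# branch becomes multiple dispatch (and each method is type-stable).
struct VecCorrBijectorSketch end       # inverse returns a correlation matrix
struct VecCholCorrBijectorSketch end   # inverse returns a Cholesky factor

# The `if mode === :C ... else ...` conditional turns into two methods:
inverse_output(::VecCorrBijectorSketch) = "correlation matrix"
inverse_output(::VecCholCorrBijectorSketch) = "Cholesky factor"
```

Each method then has a concrete return type, which is what removes the instability the conditional introduced.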
@harisorgn Any updates on this?:)
I was completely off last week. Agree with splitting/specialising the structures, will implement it this week!
Ah, no worries! Sweet!
src/chainrules.jl
```julia
return UpperTriangular(X)' * UpperTriangular(X), Δ -> begin
    Xu = UpperTriangular(X)
    return ChainRulesCore.NoTangent(), UpperTriangular(Xu * Δ + Xu * Δ')
end
```
This needs a `ChainRulesCore.unthunk`, no? https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/writing_good_rules.html#Thunks
Also, maybe add a rrule test? That would have caught the missing `unthunk`.
The thing is, I was testing the rrules locally as I was adding them and this was passing. Probably `unthunk`ing would be needed if it's part of multiple function calls that get differentiated? I am adding it anyway.
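To make the `unthunk` point concrete, here is a minimal sketch of an rrule that materialises a possibly-thunked cotangent before using it (`double` is an illustrative function, not the PR's actual rule):

```julia
using ChainRulesCore

# Illustrative only: when a rule is composed with other rules, the pullback
# may receive a lazy `Thunk` instead of a plain array, so call `unthunk`
# before doing arithmetic on the cotangent.
double(x) = 2 .* x

function ChainRulesCore.rrule(::typeof(double), x::AbstractVector)
    function double_pullback(Δ)
        Δm = unthunk(Δ)  # no-op for plain arrays, materialises a Thunk
        return NoTangent(), 2 .* Δm
    end
    return double(x), double_pullback
end

y, pb = ChainRulesCore.rrule(double, [1.0, 2.0])
pb(@thunk([1.0, 1.0]))  # works whether or not the cotangent is thunked
```

This is consistent with the reviewer's point: an rrule tested only with plain arrays can pass locally and still break once a thunked cotangent arrives from a composed call.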
@torfjelde, I implemented your suggestions, thanks for the feedback again :) I couldn't locally reproduce the… Also disregard my previous confusion about reproducing the… So there are still these two errors, plus the…

(Apologies for the format, only have phone access for now)
Great stuff @harisorgn :) Really close now!
I had a super-quick look, and made some very minor comments + changes. Once those are addressed, I think we should be good to go!
Again, awesome work; I imagine this isn't the most fun PR to work on, so appreciate you seeing this through ❤️
```diff
@@ -182,7 +188,23 @@ end
 upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]]
 J = ForwardDiff.jacobian(x -> link(dist, x), x)
-J = J[upperinds, upperinds]
+J = J[:, upperinds]
```
What was this for again? Sorry, we might have discussed this before.
Don't think we have :) It's because the output of `dist` is an `AbstractVector` now, so the indices of upper triangular elements don't apply anymore. In this test, `x` is a 3x3 matrix, `link(dist, x)` is a length-3 vector, the Jacobian is then a 3x9 matrix, and we are keeping all output elements (as they are all relevant now) and only the `upperinds` of the input elements.
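For concreteness, the indexing just described can be sketched like this (`J` is a stand-in for the `ForwardDiff.jacobian` output, not the real Jacobian):

```julia
# Quick check of the indexing logic, for a 3x3 input:
x = zeros(3, 3)
upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]]
# -> [4, 7, 8]: linear (column-major) indices of entries (1,2), (1,3), (2,3)

# link(dist, x) returns a length-3 vector, so the full Jacobian is 3x9;
# keeping all rows and only the `upperinds` columns leaves a square 3x3:
J = ones(3, 9)           # stand-in for ForwardDiff.jacobian(x -> link(dist, x), x)
size(J[:, upperinds])    # (3, 3)
```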
Aaah yeah now I remember:) We didn't discuss this but I had a think through it myself the last time I looked at it 😅 Just had a vague memory of at some point being befuddled about it and then figuring it out.
@torfjelde accidental merge, sorry, was setting up git on a new machine 😅 Please revert it and I'll implement the last changes.
Is it maybe easier if you just take over the other PR? :) #246
Expands on #246.

- Use `::AbstractVector` in `VecCorrBijector` operations, so we won't need to `transform` to `::AbstractMatrix` and back.
- Add bijector for `LKJCholesky`. I believe this was missing, and in practice it is the more efficient alternative when working with correlation matrices (avoids Cholesky decompositions on every call).
- In `LKJCholesky` there is control over the returned factor (`'U' -> UpperTriangular` or `'L' -> LowerTriangular`). I was wondering whether we want to respect the factor choice and always return the same triangular factor. If yes, we can use `VecTriuBijector` and `VecTrilBijector` to retain information about the original factor in `LKJCholesky` and return it. If no, we can always work with one type, e.g. `UpperTriangular`.

TO DO:

- Add `ChainRulesCore.rrule`s for all link functions that work on `::AbstractVector`, defined in this PR. I have only added one rule for the forward link function, but `ChainRulesTestUtils.test_rrule` complains about type instability and value mismatch. When comparing the values returned by the pullback inside the closure of `rrule` against the one defined for Zygote, I'm getting the same output though. I will have more of a look next week.
- ~~Document how I ended up with `_logabsdetjac_inv_chol`, so it can be verified.~~ This was based on the Stan manual pages for correlation matrices and Cholesky factors of correlation matrices. EDIT 2: I have not documented the formula derivation but added a test for it that passes.
- Remove this dispatch `function _link_chol_lkj(W::LowerTriangular)` (Bijectors.jl/src/bijectors/corr.jl, Line 320 in 7f5d0fc) and use `transpose(W::UpperTriangular)`.
- Related to the second point, right above: in general it would be nice if we could test these analytical formulas for logabsdetjac derived by hand. I played around with it a bit, but couldn't come up with something. EDIT: This can be done using AD. I see there is something already implemented along these lines in test/transform.jl, just needs some tweaking.

cc @torfjelde if you want to have a look already
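The AD cross-check idea from the last TO DO item can be sketched on a simple scalar bijection (softplus here, not the correlation transforms, and with a finite difference standing in for AD to keep the snippet dependency-free):

```julia
# Cross-check sketch: a hand-derived logabsdetjac should match log|f'(x)|
# computed independently (by AD, or numerically as below).
f(x) = log1p(exp(x))                        # bijection R -> R+ (softplus)
logabsdetjac_hand(x) = x - log1p(exp(x))    # since f'(x) = sigmoid(x)

x = 0.3
h = 1e-6
J = (f(x + h) - f(x - h)) / (2h)            # numerical f'(x)
isapprox(log(abs(J)), logabsdetjac_hand(x); atol=1e-8)  # true
```

The same pattern with `ForwardDiff.jacobian` and `logabsdet` of the restricted Jacobian is presumably what the test/transform.jl machinery mentioned above does for the matrix-valued transforms.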