Added new transforms submodules and fill_scale_tril.py transform to convert vectors into lower triangular matrices #5
Conversation
The code is failing because torch.distributions.transforms.ComposeTransform requires me to define the variables.
This looks like a good start. I think the biggest things that need addressing are:
- support ComposeTransform and add a test to make sure it works (maybe this is not possible for some reason?).
- support bijectors other than softplus for the diagonal. I maintain that hardcoding the derivatives in the log_det_jacobian block is too restrictive; it should be relatively easy to get the necessary info from self.diag_transform (see the sketch below).
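A minimal sketch of what the second point could look like, assuming the unconstrained entries that map to the diagonal are available as a tensor (the `diag_log_det` helper and `x_diag` name are hypothetical, not code from this PR):

```python
import torch
from torch.distributions.transforms import SoftplusTransform

def diag_log_det(diag_transform, x_diag):
    # Every torch Transform (including ComposeTransform) implements
    # log_abs_det_jacobian, so nothing about softplus is hardcoded here.
    y_diag = diag_transform(x_diag)
    return diag_transform.log_abs_det_jacobian(x_diag, y_diag).sum(-1)

# Works identically for any diagonal bijector:
print(diag_log_det(SoftplusTransform(), torch.randn(3)))
```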
""" | ||
super().__init__() | ||
self.diag_transform = ( | ||
diag_transform if diag_transform is not None else SoftplusTransform() |
This is not a bug, but we should discuss as a group whether we want to use ExpTransform or SoftplusTransform in cases like these. Either way, you might consider refactoring this to grab the appropriate constraint for positivity from the constraints registry (https://pytorch.org/docs/stable/distributions.html#module-torch.distributions.constraint_registry).
Following up on this, the diag shift stuff should be included in the constraint registry for positive constraints. However, I think the default bijection for positive variables doesn't include an offset.
Why don't you use the constraint registry to populate the diagonal transform if it is not specified? You can do this with torch.distributions.transform_to(constraint).
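As a hedged illustration of that suggestion (not code from the PR): the registry lookup is a one-liner, and its default bijection onto the positives carries no offset, which is why the diag-shift question above still matters.

```python
import torch
from torch.distributions import constraints, transform_to

# Ask the constraint registry for a transform onto the positive reals.
diag_transform = transform_to(constraints.positive)

y = diag_transform(torch.randn(5))
print((y > 0).all())  # tensor(True): strictly positive, but no offset,
                      # so any diag_shift would still have to be composed in.
```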
I think having a separate "diag_shift" argument is sloppy. Why not make the default diag_transform something like this?

```python
self.diag_transform = torch.distributions.ComposeTransform((
    torch.distributions.SoftplusTransform(),
    torch.distributions.AffineTransform(1e-5, 1.),
))
```

If people want a different shift, they can supply it via the transform. Thoughts?
This is an older commit; I have removed the diag_shift in the updated version.
My bad, I see what you did now. I think it is good to keep the shift in the default but implement it as part of the transform. I would recommend the following pattern:

```python
def __init__(self, diag_transform=None):
    if diag_transform is None:
        diag_transform = torch.distributions.ComposeTransform((
            torch.distributions.SoftplusTransform(),
            torch.distributions.AffineTransform(1e-5, 1.),
        ))
    super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])
    self.diag_transform = diag_transform
```
Okay, I added this to FillScaleTriL.
… instead of .copy_
I'm requesting a couple of changes. First, can we have a params_size staticmethod attached to the FillScaleTriL transform? That way no one has to remember the formula for the vector size. Second, I don't like mixing hardcoded diag_shift parameters with transforms. If you want your bijector to have a shift, it should just be in the bijector.
""" | ||
super().__init__() | ||
self.diag_transform = ( | ||
diag_transform if diag_transform is not None else SoftplusTransform() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think having a separate "diag_shift" argument is sloppy. why not make the default diag_transform
self.diag_transform = torch.distributions.ComposeTransform((
torch.distributions.SoftplusTransform(),
torch.distributions.AffineTransform(1e-5, 1.),
))
? if people want a different shift, they can supply it by the transform. thoughts?
```python
from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec


class FillScaleTriL(Transform):
```
I request you add a params_size staticmethod a la tfp's multivariate normal: https://www.tensorflow.org/probability/api_docs/python/tfp/layers/MultivariateNormalTriL#params_size
Something like:

```python
@staticmethod
def params_size(event_size):
    # for lower triangular we need n*(n+1)//2 elements
    return event_size * (event_size + 1) // 2
```
Yeah, that's what I am thinking. It's helpful to have when you are in dimensions greater than 3.
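For concreteness, a quick check of the formula (standalone here; in the PR it would live on FillScaleTriL as the staticmethod above):

```python
def params_size(event_size):
    # an n-by-n lower triangular matrix has n*(n+1)//2 free entries
    return event_size * (event_size + 1) // 2

assert params_size(4) == 10   # easy enough to do in your head
assert params_size(10) == 55  # much less obvious in higher dimensions
```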
```python
@staticmethod
def params_size(event_size):
    """
    Returns the number of parameters required to create lower triangular matrix, which is given by n*(n+1)//2
```
Docstring suggestion: Returns the number of parameters required to create an n-by-n lower triangular matrix...
Looks good! Thanks
This PR adds a `transforms` submodule directory as well as a `fill_scale_tril.py` transform to convert vectors into a lower triangular matrix with a positive diagonal. The current implementation uses a `SoftplusTransform` to constrain the diagonal. This should be compatible with torch.distributions.transforms.ComposeTransform, but I have not tested it.
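Since the ComposeTransform compatibility is untested, a rough round-trip check might look like the following sketch. The import path and the assumption that FillScaleTriL implements an inverse are mine, and params_size is the staticmethod discussed in the review:

```python
import torch
from torch.distributions import AffineTransform, ComposeTransform

# Hypothetical import path for the transform added in this PR.
from transforms.fill_scale_tril import FillScaleTriL

fill = FillScaleTriL()
composed = ComposeTransform([AffineTransform(0.0, 2.0), fill])

n = 3
x = torch.randn(FillScaleTriL.params_size(n))  # 3 * 4 // 2 == 6
L = composed(x)

# Lower triangular with a strictly positive diagonal...
assert torch.equal(L, L.tril())
assert (L.diagonal(dim1=-2, dim2=-1) > 0).all()
# ...and the composition should invert back to the input vector.
assert torch.allclose(composed.inv(L), x, atol=1e-5)
```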