
Added new transforms submodules and fill_scale_tril.py transform to convert vectors into lower triangular matrices #5

Merged 13 commits into rs-station:main on May 28, 2024

Conversation

LuisA92 (Collaborator) commented on Apr 25, 2024:

This PR adds a transforms submodule directory as well as a `fill_scale_tril.py` transform to convert vectors into lower triangular matrices with a positive diagonal. The current implementation uses a `SoftplusTransform` to constrain the diagonal. This should be compatible with `torch.distributions.transforms.ComposeTransform`, but I have not tested it.
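To make the intent concrete, here is a minimal sketch of the forward direction, using `vec_to_tril_matrix` from `torch.distributions.utils`; this is an illustration, not the PR's actual implementation:

```python
import torch
from torch.distributions.transforms import SoftplusTransform
from torch.distributions.utils import vec_to_tril_matrix

# A flat vector of length n*(n+1)//2 fills an n-by-n lower triangle,
# and SoftplusTransform maps the diagonal onto the positive reals.
x = torch.randn(6)                       # 6 = 3*(3+1)//2, so n = 3
L = vec_to_tril_matrix(x)                # (3, 3) lower triangular
d = L.diagonal(dim1=-2, dim2=-1)
L = L - torch.diag_embed(d) + torch.diag_embed(SoftplusTransform()(d))
print(L)                                 # lower triangular, positive diagonal
```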

LuisA92 requested review from kmdalton and minhuanli on Apr 25, 2024.
LuisA92 (Collaborator, Author) commented on Apr 25, 2024:

The code is failing because `torch.distributions.transforms.ComposeTransform` requires me to define the attributes `domain`, `codomain`, and `bijective`, but they are not used in the class.
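For reference, a minimal sketch of how those attributes can be declared on a `Transform` subclass; the specific constraint choices below are an assumption, not necessarily the PR's final code:

```python
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class FillScaleTriL(Transform):
    # ComposeTransform inspects these class attributes, so they must be
    # declared even though nothing else in the class reads them directly.
    domain = constraints.real_vector       # unconstrained input vector
    codomain = constraints.lower_cholesky  # lower triangular, positive diagonal
    bijective = True
    # _call, _inverse, and log_abs_det_jacobian omitted here
```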

kmdalton (Member) left a comment:

This looks like a good start. I think the biggest things that need addressing are:

1. Support `ComposeTransform` and add a test to make sure it works (maybe this is not possible for some reason?).
2. Support bijectors other than softplus for the diagonal. I maintain that hardcoding the derivatives in the `log_det_jacobian` block is too restrictive; it should be relatively easy to get the necessary info from `self.diag_transform` (see the sketch after this list).
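A hypothetical method body for point 2, assuming only the diagonal entries are transformed nonlinearly (the off-diagonal entries are copied, so they contribute nothing to the log-determinant):

```python
from torch.distributions.utils import vec_to_tril_matrix

def log_abs_det_jacobian(self, x, y):
    # Delegate the diagonal's derivative to whatever bijector was supplied,
    # instead of hardcoding the softplus case.
    diag_x = vec_to_tril_matrix(x).diagonal(dim1=-2, dim2=-1)
    diag_y = y.diagonal(dim1=-2, dim2=-1)
    return self.diag_transform.log_abs_det_jacobian(diag_x, diag_y).sum(-1)
```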

"""
super().__init__()
self.diag_transform = (
diag_transform if diag_transform is not None else SoftplusTransform()
Member:

This is not a bug, but we should discuss as a group whether we want to use `ExpTransform` or `SoftplusTransform` in cases like these. Either way, you might consider refactoring this to grab the appropriate constraint for positivity from the constraints registry (https://pytorch.org/docs/stable/distributions.html#module-torch.distributions.constraint_registry).

Member:

Following up on this: the diag shift stuff should be included in the constraint registry for positive constraints. However, I think the default bijection for positive variables doesn't include an offset.

Member:

Why don't you use the constraint registry to populate the diagonal transform if it is not specified? You can do this with `torch.distributions.transform_to(constraint)`.
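A brief sketch of that approach; note the comment above that the registry's default bijector for positivity carries no built-in offset:

```python
import torch
from torch.distributions import constraints, transform_to

# Pull the default positivity bijector from the constraint registry.
diag_transform = transform_to(constraints.positive)
x = torch.randn(4)
y = diag_transform(x)
print(y)                      # strictly positive values
print(diag_transform.inv(y))  # recovers x up to numerical error
```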

Member:

I think having a separate `diag_shift` argument is sloppy. Why not make the default `diag_transform`

```python
self.diag_transform = torch.distributions.ComposeTransform((
    torch.distributions.SoftplusTransform(),
    torch.distributions.AffineTransform(1e-5, 1.),
))
```

? If people want a different shift, they can supply it via the transform. Thoughts?

LuisA92 (Collaborator, Author):

This is an older commit; I have removed `diag_shift` in the updated version.

Member:

My bad, I see what you did now. I think it is good to keep the shift in the default but implement it as part of the transform. I would recommend the following pattern:

```python
def __init__(self, diag_transform=None):
    if diag_transform is None:
        diag_transform = torch.distributions.ComposeTransform((
            torch.distributions.SoftplusTransform(),
            torch.distributions.AffineTransform(1e-5, 1.),
        ))
    super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])
    self.diag_transform = diag_transform
```
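A hypothetical round trip with this pattern (`FillTriL` and `DiagTransform` are the composed parts from this PR, not stock PyTorch):

```python
import torch

t = FillScaleTriL()
x = torch.randn(6)   # parameters for a 3x3 lower triangular matrix
L = t(x)             # lower triangular, diagonal shifted to be >= 1e-5
assert torch.allclose(t.inv(L), x, atol=1e-4)
```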

LuisA92 (Collaborator, Author):

Okay, I added this to `FillScaleTriL`.

(Additional review threads on src/rs_distributions/transforms/fill_scale_tril.py and tests/transforms/fill_scale_tril.py were marked outdated and resolved.)
kmdalton (Member) left a comment:

I'm requesting a couple of changes. First, can we have a `params_size` staticmethod attached to the `FillScaleTriL` transform? That way no one has to remember the formula for the vector size. Second, I don't like mixing hardcoded `diag_shift` parameters with transforms; if you want your bijector to have a shift, it should just be in the bijector.

"""
super().__init__()
self.diag_transform = (
diag_transform if diag_transform is not None else SoftplusTransform()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think having a separate "diag_shift" argument is sloppy. why not make the default diag_transform

self.diag_transform = torch.distributions.ComposeTransform((
    torch.distributions.SoftplusTransform(),
    torch.distributions.AffineTransform(1e-5, 1.),
))

? if people want a different shift, they can supply it by the transform. thoughts?

On this snippet:

```python
from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec


class FillScaleTriL(Transform):
```
Member:

I request you add a params_size staticmethod a la tfp's multivariate normal: https://www.tensorflow.org/probability/api_docs/python/tfp/layers/MultivariateNormalTriL#params_size

LuisA92 (Collaborator, Author):

Something like:

```python
@staticmethod
def params_size(event_size):
    # for an n-by-n lower triangular matrix we need n*(n+1)//2 elements
    return event_size * (event_size + 1) // 2
```

Member:

Yeah, that's what I am thinking. It's helpful to have when you are in dimensions greater than 3.
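For instance, assuming the staticmethod above:

```python
assert FillScaleTriL.params_size(3) == 6     # 3x3 lower triangle
assert FillScaleTriL.params_size(10) == 55   # 10x10 lower triangle
```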

On this snippet:

```python
@staticmethod
def params_size(event_size):
    """
    Returns the number of parameters required to create lower triangular matrix, which is given by n*(n+1)//2
```
Member:

Docstring suggestion: Returns the number of parameters required to create an n-by-n lower triangular matrix...

kmdalton (Member) left a comment:

Looks good! Thanks

LuisA92 merged commit 55f9957 into rs-station:main on May 28, 2024. 15 checks passed.