diff --git a/src/rs_distributions/transforms/__init__.py b/src/rs_distributions/transforms/__init__.py
new file mode 100644
index 0000000..73ce3bc
--- /dev/null
+++ b/src/rs_distributions/transforms/__init__.py
@@ -0,0 +1,7 @@
+from .fill_scale_tril import FillScaleTriL, FillTriL, DiagTransform
+
+__all__ = [
+    "FillScaleTriL",
+    "FillTriL",
+    "DiagTransform",
+]
diff --git a/src/rs_distributions/transforms/fill_scale_tril.py b/src/rs_distributions/transforms/fill_scale_tril.py
new file mode 100644
index 0000000..8b5cbe4
--- /dev/null
+++ b/src/rs_distributions/transforms/fill_scale_tril.py
@@ -0,0 +1,128 @@
+import torch
+from torch.distributions import Transform, ComposeTransform, constraints
+from torch.distributions.transforms import SoftplusTransform, AffineTransform
+from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec
+
+
+class FillTriL(Transform):
+    """
+    Transform a real-valued vector into a lower triangular matrix.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def domain(self):
+        return constraints.real_vector
+
+    @property
+    def codomain(self):
+        return constraints.lower_triangular
+
+    @property
+    def bijective(self):
+        return True
+
+    def _call(self, x):
+        """
+        Convert a real-valued vector into a lower triangular matrix.
+
+        Args:
+            x (torch.Tensor): input real-valued vector
+        Returns:
+            torch.Tensor: lower triangular matrix
+        """
+        return vec_to_tril_matrix(x)
+
+    def _inverse(self, y):
+        return tril_matrix_to_vec(y)
+
+    def log_abs_det_jacobian(self, x, y):
+        # Filling a vector into a triangular matrix is a reshape,
+        # so the log-determinant of the Jacobian is zero.
+        batch_shape = x.shape[:-1]
+        return torch.zeros(batch_shape, dtype=x.dtype, device=x.device)
+
+
+class DiagTransform(Transform):
+    """
+    Apply an elementwise transform to the diagonal of a square matrix.
+    """
+
+    def __init__(self, diag_transform):
+        super().__init__()
+        self.diag_transform = diag_transform
+
+    @property
+    def domain(self):
+        return self.diag_transform.domain
+
+    @property
+    def codomain(self):
+        return self.diag_transform.codomain
+
+    @property
+    def bijective(self):
+        return self.diag_transform.bijective
+
+    def _call(self, x):
+        """
+        Args:
+            x (torch.Tensor): input matrix
+        Returns:
+            torch.Tensor: matrix with transformed diagonal
+        """
+        diagonal = x.diagonal(dim1=-2, dim2=-1)
+        transformed_diagonal = self.diag_transform(diagonal)
+        return x.diagonal_scatter(transformed_diagonal, dim1=-2, dim2=-1)
+
+    def _inverse(self, y):
+        diagonal = y.diagonal(dim1=-2, dim2=-1)
+        return y.diagonal_scatter(self.diag_transform.inv(diagonal), dim1=-2, dim2=-1)
+
+    def log_abs_det_jacobian(self, x, y):
+        # Off-diagonal entries are untouched, so only the diagonal
+        # transform contributes to the Jacobian.
+        return self.diag_transform.log_abs_det_jacobian(
+            x.diagonal(dim1=-2, dim2=-1), y.diagonal(dim1=-2, dim2=-1)
+        )
+
+
+class FillScaleTriL(ComposeTransform):
+    """
+    A `ComposeTransform` that reshapes a real-valued vector into a lower
+    triangular matrix whose diagonal is transformed with `diag_transform`.
+    """
+
+    def __init__(self, diag_transform=None):
+        if diag_transform is None:
+            # Default: softplus makes the diagonal positive and the affine
+            # shift keeps it bounded away from zero.
+            diag_transform = ComposeTransform(
+                (
+                    SoftplusTransform(),
+                    AffineTransform(1e-5, 1.0),
+                )
+            )
+        super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])
+        self.diag_transform = diag_transform
+
+    @property
+    def bijective(self):
+        return True
+
+    def log_abs_det_jacobian(self, x, y):
+        # FillTriL contributes zero, so only the diagonal transform matters.
+        diagonal = FillTriL()(x).diagonal(dim1=-2, dim2=-1)
+        return self.diag_transform.log_abs_det_jacobian(
+            diagonal, y.diagonal(dim1=-2, dim2=-1)
+        )
+
+    @staticmethod
+    def params_size(event_size):
+        """
+        Return the number of parameters required to fill an n-by-n lower
+        triangular matrix: n * (n + 1) // 2.
+
+        Args:
+            event_size (int): event size n
+        Returns:
+            int: number of parameters needed
+        """
+        return event_size * (event_size + 1) // 2
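For review context, here is a minimal usage sketch (not part of the diff): an unconstrained vector sized with `params_size` is pushed through the transform to obtain a valid Cholesky factor for `torch.distributions.MultivariateNormal`. The variable names are invented for illustration.

```python
import torch
from torch.distributions import MultivariateNormal
from rs_distributions.transforms import FillScaleTriL

event_size = 3
transform = FillScaleTriL()

# Unconstrained parameters, e.g. the output of a neural network
theta = torch.randn(FillScaleTriL.params_size(event_size), requires_grad=True)

# Lower triangular factor with a strictly positive diagonal
scale_tril = transform(theta)

mvn = MultivariateNormal(loc=torch.zeros(event_size), scale_tril=scale_tril)
sample = mvn.rsample()  # differentiable w.r.t. theta
```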
+ """ + + def __init__(self, diag_transform=None): + if diag_transform is None: + diag_transform = torch.distributions.ComposeTransform( + ( + SoftplusTransform(), + AffineTransform(1e-5, 1.0), + ) + ) + super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)]) + self.diag_transform = diag_transform + + @property + def bijective(self): + return True + + def log_abs_det_jacobian(self, x, y): + x = FillTriL()._call(x) + diagonal = x.diagonal(dim1=-2, dim2=-1) + return self.diag_transform.log_abs_det_jacobian(diagonal, diagonal) + + @staticmethod + def params_size(event_size): + """ + Returns the number of parameters required to create an n-by-n lower triangular matrix, which is given by n*(n+1)//2 + + Args: + event_size (int): size of event + Returns: + int: Number of parameters needed + + """ + return event_size * (event_size + 1) // 2 diff --git a/tests/transforms/fill_scale_tril.py b/tests/transforms/fill_scale_tril.py new file mode 100644 index 0000000..8f01cae --- /dev/null +++ b/tests/transforms/fill_scale_tril.py @@ -0,0 +1,73 @@ +import pytest +from rs_distributions.transforms.fill_scale_tril import ( + FillScaleTriL, + FillTriL, + DiagTransform, +) +import torch +from torch.distributions.constraints import lower_cholesky +from torch.distributions.transforms import SoftplusTransform, ExpTransform + + +@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)]) +def test_forward_transform(batch_shape, d): + transform = FillScaleTriL() + input_shape = batch_shape + (d,) + input_vector = torch.randn(input_shape) + transformed_vector = transform(input_vector) + + n = int((-1 + torch.sqrt(torch.tensor(1 + 8 * d))) / 2) + expected_output_shape = batch_shape + (n, n) + cholesky_constraint_check = lower_cholesky.check(transformed_vector) + + assert isinstance(transformed_vector, torch.Tensor), "Output is not a torch.Tensor" + assert ( + transformed_vector.shape == expected_output_shape + ), f"Expected shape {expected_output_shape}, got {transformed_vector.shape}" + assert cholesky_constraint_check.all() + + +@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)]) +def test_forward_equals_inverse(batch_shape, d): + transform = FillScaleTriL() + input_shape = batch_shape + (d,) + input_vector = torch.randn(input_shape) + L = transform(input_vector) + invL = transform.inv(L) + + assert torch.allclose( + input_vector, invL, atol=1e-4 + ), "Original input and the result of applying inverse transformation are not close enough" + + +@pytest.mark.parametrize( + "batch_shape, d, diag_transform", + [ + ((2, 3), 6, SoftplusTransform()), + ((1, 4, 5), 10, SoftplusTransform()), + ((2, 3), 6, ExpTransform()), + ((1, 4, 5), 10, ExpTransform()), + ], +) +def test_log_abs_det_jacobian_softplus_and_exp(batch_shape, d, diag_transform): + transform = FillScaleTriL(diag_transform=diag_transform) + filltril = FillTriL() + diagtransform = DiagTransform(diag_transform=diag_transform) + input_shape = batch_shape + (d,) + input_vector = torch.randn(input_shape, requires_grad=True) + transformed_vector = transform(input_vector) + + # Calculate gradients log_abs_det_jacobian from FillScaleTriL + log_abs_det_jacobian = transform.log_abs_det_jacobian( + input_vector, transformed_vector + ) + + # Extract diagonal elements from input and transformed vectors + tril = filltril(input_vector) + diagonal_transformed = diagtransform(tril) + + # Calculate diagonal gradients + diag_jacobian = diagtransform.log_abs_det_jacobian(tril, diagonal_transformed) + + # Assert diagonal 
+    # The two should agree, since FillTriL contributes nothing
+    assert torch.allclose(diag_jacobian, log_abs_det_jacobian, atol=1e-4)
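A follow-up test one could add later (a sketch, not part of this PR): cross-check the analytic `log_abs_det_jacobian` against a brute-force autograd Jacobian for a single unbatched vector. The helper `flat` is hypothetical; it flattens the matrix output so autograd sees a square, vector-to-vector map.

```python
import torch
from torch.distributions.utils import tril_matrix_to_vec
from rs_distributions.transforms import FillScaleTriL

transform = FillScaleTriL()
x = torch.randn(6)  # fills a 3x3 lower triangular matrix


def flat(v):
    # Hypothetical helper: maps R^6 -> R^6 so the Jacobian is square
    return tril_matrix_to_vec(transform(v))


jacobian = torch.autograd.functional.jacobian(flat, x)
_, brute_force = torch.linalg.slogdet(jacobian)

# The analytic result is elementwise over the diagonal; sum it
analytic = transform.log_abs_det_jacobian(x, transform(x)).sum()
assert torch.allclose(analytic, brute_force, atol=1e-5)
```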