From 4979d9ef02028f01c8cb75b767658568945d0fd3 Mon Sep 17 00:00:00 2001
From: Luis Aldama
Date: Wed, 24 Apr 2024 23:49:59 -0400
Subject: [PATCH] Added new transforms submodules and fill_scale_tril.py
 transform to convert vectors into lower triangular matrices

---
 src/rs_distributions/transforms/__init__.py   |  5 +
 .../transforms/fill_scale_tril.py             | 88 +++++++++++++++++++
 tests/transforms/fill_scale_tril.py           | 46 ++++++++++
 3 files changed, 139 insertions(+)
 create mode 100644 src/rs_distributions/transforms/__init__.py
 create mode 100644 src/rs_distributions/transforms/fill_scale_tril.py
 create mode 100644 tests/transforms/fill_scale_tril.py

diff --git a/src/rs_distributions/transforms/__init__.py b/src/rs_distributions/transforms/__init__.py
new file mode 100644
index 0000000..553c7bd
--- /dev/null
+++ b/src/rs_distributions/transforms/__init__.py
@@ -0,0 +1,5 @@
+from .fill_scale_tril import FillScaleTriL
+
+__all__ = [
+    "FillScaleTriL",
+]
diff --git a/src/rs_distributions/transforms/fill_scale_tril.py b/src/rs_distributions/transforms/fill_scale_tril.py
new file mode 100644
index 0000000..6f931a9
--- /dev/null
+++ b/src/rs_distributions/transforms/fill_scale_tril.py
@@ -0,0 +1,88 @@
+import torch
+from torch.distributions import Transform, constraints
+from torch.distributions.transforms import SoftplusTransform
+from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec
+
+
+class FillScaleTriL(Transform):
+    """
+    Bijection from real vectors to lower triangular matrices with positive diagonals.
+    """
+
+    domain = constraints.real_vector
+    codomain = constraints.lower_cholesky
+    bijective = True
+
+    def __init__(self, diag_transform=None, diag_shift=1e-05):
+        """
+        Converts a tensor into a lower triangular matrix with positive diagonal entries.
+
+        Args:
+            diag_transform: elementwise transformation used on diagonal to ensure
+                positive values. Default is SoftplusTransform
+            diag_shift (float): small offset to avoid diagonals very close to zero.
+                Default offset is 1e-05. Pass None to disable the offset.
+        """
+        super().__init__()
+        self.diag_transform = (
+            diag_transform if diag_transform is not None else SoftplusTransform()
+        )
+        self.diag_shift = diag_shift
+
+    def _call(self, x):
+        """
+        Transform input vector to lower triangular.
+
+        Args:
+            x (torch.Tensor): Input vector to transform
+
+        Returns:
+            torch.Tensor: Transformed lower triangular matrix
+        """
+        tril = vec_to_tril_matrix(x)
+        diagonal = self.diag_transform(tril.diagonal(dim1=-2, dim2=-1))
+        if self.diag_shift is not None:
+            diagonal = diagonal + self.diag_shift
+        # Assemble the result out of place: writing into the diagonal view would
+        # overwrite the input diag_transform was applied to, which autograd may need.
+        return tril.tril(-1) + torch.diag_embed(diagonal)
+
+    def _inverse(self, y):
+        """
+        Apply the inverse transformation to the input lower triangular matrix.
+
+        Args:
+            y (torch.Tensor): Invertible lower triangular matrix. It is not modified.
+
+        Returns:
+            torch.Tensor: Inversely transformed vector
+        """
+        diagonal = y.diagonal(dim1=-2, dim2=-1)
+        if self.diag_shift is not None:
+            diagonal = diagonal - self.diag_shift
+        unconstrained = self.diag_transform.inv(diagonal)
+        # Build a new matrix so the caller's ``y`` keeps its diagonal.
+        return tril_matrix_to_vec(y.tril(-1) + torch.diag_embed(unconstrained))
+
+    def log_abs_det_jacobian(self, x, y):
+        """
+        Computes the log absolute determinant of the Jacobian matrix for the transformation.
+
+        Off-diagonal entries are copied unchanged, so only the diagonal transform
+        contributes. Its elementwise log-Jacobian is evaluated at the diagonal
+        inputs taken from ``x`` and summed over the diagonal.
+
+        Args:
+            x (torch.Tensor): Input vector before transformation
+            y (torch.Tensor): Output lower triangular matrix from _call
+
+        Returns:
+            torch.Tensor: Log absolute determinant of the Jacobian matrix, one
+                value per batch element
+        """
+        x_diagonal = vec_to_tril_matrix(x).diagonal(dim1=-2, dim2=-1)
+        y_diagonal = y.diagonal(dim1=-2, dim2=-1)
+        if self.diag_shift is not None:
+            y_diagonal = y_diagonal - self.diag_shift
+        log_det = self.diag_transform.log_abs_det_jacobian(x_diagonal, y_diagonal)
+        return log_det.sum(-1)
diff --git a/tests/transforms/fill_scale_tril.py b/tests/transforms/fill_scale_tril.py
new file mode 100644
index 0000000..43f126e
--- /dev/null
+++ b/tests/transforms/fill_scale_tril.py
@@ -0,0 +1,46 @@
+import math
+
+import pytest
+from rs_distributions.transforms.fill_scale_tril import FillScaleTriL
+import torch
+from torch.distributions.constraints import lower_cholesky
+
+
+@pytest.mark.parametrize("input_shape", [(6,), (10,)])
+def test_forward_transform(input_shape):
+    transform = FillScaleTriL()
+    input_vector = torch.randn(input_shape)
+    transformed_vector = transform._call(input_vector)
+
+    # A vector of length n * (n + 1) / 2 fills an n x n lower triangle
+    n = (math.isqrt(1 + 8 * input_shape[0]) - 1) // 2
+    assert isinstance(transformed_vector, torch.Tensor)
+    assert transformed_vector.shape == (n, n)
+    assert lower_cholesky.check(transformed_vector)
+
+
+@pytest.mark.parametrize("input_vector", [torch.randn(3), torch.randn(6)])
+def test_forward_equals_inverse(input_vector):
+    transform = FillScaleTriL()
+    L = transform._call(input_vector)
+    L_before = L.clone()
+    invL = transform._inverse(L)
+
+    assert torch.allclose(input_vector, invL, atol=1e-6)
+    # The inverse must leave the matrix it was given untouched
+    assert torch.equal(L, L_before)
+
+
+@pytest.mark.parametrize("batch_shape", [(), (4,)])
+def test_log_abs_det_jacobian(batch_shape):
+    transform = FillScaleTriL()
+    input_vector = torch.randn(*batch_shape, 6)
+    L = transform._call(input_vector)
+    log_det = transform.log_abs_det_jacobian(input_vector, L)
+
+    # Entries 0, 2 and 5 of the vector land on the diagonal of the 3 x 3 matrix;
+    # the default softplus has derivative sigmoid at those inputs
+    diagonal_inputs = input_vector[..., [0, 2, 5]]
+    expected = torch.nn.functional.logsigmoid(diagonal_inputs).sum(-1)
+    assert log_det.shape == batch_shape
+    assert torch.allclose(log_det, expected, atol=1e-6)