diff --git a/src/rs_distributions/transforms/fill_scale_tril.py b/src/rs_distributions/transforms/fill_scale_tril.py
index 6f931a9..a741c48 100644
--- a/src/rs_distributions/transforms/fill_scale_tril.py
+++ b/src/rs_distributions/transforms/fill_scale_tril.py
@@ -5,7 +5,7 @@ class FillScaleTriL(Transform):
-    def __init__(self, diag_transform=None, diag_shift=1e-05):
+    def __init__(self, diag_transform=None, diag_shift=1e-06):
         """
         Converts a tensor into a lower triangular matrix with positive
         diagonal entries.

@@ -13,7 +13,7 @@ def __init__(self, diag_transform=None, diag_shift=1e-05):
             diag_transform: transformation used on diagonal to ensure positive values.
                 Default is SoftplusTransform
             diag_shift (float): small offset to avoid diagonals very close to zero.
-                Default offset is 1e-05
+                Default offset is 1e-06
         """

         super().__init__()
@@ -22,9 +22,17 @@ def __init__(self, diag_transform=None, diag_shift=1e-05):
         )
         self.diag_shift = diag_shift

-    domain = constraints.real_vector
-    codomain = constraints.lower_cholesky
-    bijective = True
+    @property
+    def domain(self):
+        return constraints.real_vector
+
+    @property
+    def codomain(self):
+        return constraints.lower_cholesky
+
+    @property
+    def bijective(self):
+        return True

     def _call(self, x):
         """
@@ -36,12 +44,14 @@ def _call(self, x):
             torch.Tensor: Transformed lower triangular matrix
         """
         x = vec_to_tril_matrix(x)
-        diagonal_elements = x.diagonal(dim1=-2, dim2=-1)
-        transformed_diagonal = self.diag_transform(diagonal_elements)
+        diagonal = x.diagonal(dim1=-2, dim2=-1)
         if self.diag_shift is not None:
-            transformed_diagonal += self.diag_shift
-        x.diagonal(dim1=-2, dim2=-1).copy_(transformed_diagonal)
-        return x
+            result = x.diagonal_scatter(
+                self.diag_transform(diagonal + self.diag_shift), dim1=-2, dim2=-1
+            )
+        else:
+            result = x.diagonal_scatter(self.diag_transform(diagonal), dim1=-2, dim2=-1)
+        return result

     def _inverse(self, y):
         """
@@ -54,31 +64,29 @@ def _inverse(self, y):
             torch.Tensor: Inversely transformed vector
         """
-        diagonal_elements = y.diagonal(dim1=-2, dim2=-1)
+        diagonal = y.diagonal(dim1=-2, dim2=-1)
         if self.diag_shift is not None:
-            transformed_diagonal = self.diag_transform.inv(
-                diagonal_elements - self.diag_shift
+            result = y.diagonal_scatter(
+                self.diag_transform.inv(diagonal - self.diag_shift), dim1=-2, dim2=-1
             )
         else:
-            transformed_diagonal = self.diag_transform.inv(diagonal_elements)
-        y.diagonal(dim1=-2, dim2=-1).copy_(transformed_diagonal)
-        return tril_matrix_to_vec(y)
+            result = y.diagonal_scatter(
+                self.diag_transform.inv(diagonal), dim1=-2, dim2=-1
+            )
+        return tril_matrix_to_vec(result)

     def log_abs_det_jacobian(self, x, y):
-        """
-        Computes the log absolute determinant of the Jacobian matrix for the transformation.
-
-        Assumes that Softplus is used on the diagonal.
-        The derivative of the softplus function is the sigmoid function.
-
-        Args:
-            x (torch.Tensor): Input vector before transformation
-            y (torch.Tensor): Output lower triangular matrix from _call
-
-        Returns:
-            torch.Tensor: Log absolute determinant of the Jacobian matrix
-        """
-        diag_elements = y.diagonal(dim1=-2, dim2=-1)
-        derivatives = torch.sigmoid(diag_elements)
-        log_det_jacobian = torch.log(derivatives).sum()
+        L = vec_to_tril_matrix(x)
+        diag = L.diagonal(dim1=-2, dim2=-1)
+        diag.requires_grad_(True)
+        if self.diag_shift is not None:
+            transformed_diag = self.diag_transform(diag + self.diag_shift)
+        else:
+            transformed_diag = self.diag_transform(diag)
+        derivatives = torch.autograd.grad(
+            outputs=transformed_diag,
+            inputs=diag,
+            grad_outputs=torch.ones_like(transformed_diag),
+        )[0]
+        log_det_jacobian = torch.log(torch.abs(derivatives)).sum()
         return log_det_jacobian
diff --git a/tests/transforms/fill_scale_tril.py b/tests/transforms/fill_scale_tril.py
index 43f126e..61a005b 100644
--- a/tests/transforms/fill_scale_tril.py
+++ b/tests/transforms/fill_scale_tril.py
@@ -1,27 +1,43 @@
 import pytest
 from rs_distributions.transforms.fill_scale_tril import FillScaleTriL
 import torch
+from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec
 from torch.distributions.constraints import lower_cholesky
+from torch.distributions.transforms import (
+    ComposeTransform,
+    ExpTransform,
+    SoftplusTransform,
+)


-@pytest.mark.parametrize("input_shape", [(6,), (10,)])
-def test_forward_transform(input_shape):
+@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
+def test_forward_transform(batch_shape, d):
     transform = FillScaleTriL()
+    input_shape = batch_shape + (d,)
     input_vector = torch.randn(input_shape)
     transformed_vector = transform._call(input_vector)

-    assert isinstance(transformed_vector, torch.Tensor)
-    assert transformed_vector.shape == (
-        (-1 + torch.sqrt(torch.tensor(1 + input_shape[0] * 8))) / 2,
-        (-1 + torch.sqrt(torch.tensor(1 + input_shape[0] * 8))) / 2,
-    )
-    assert lower_cholesky.check(transformed_vector)
+    n = int((-1 + torch.sqrt(torch.tensor(1 + 8 * d))) / 2)
+    expected_output_shape = batch_shape + (n, n)
+    cholesky_constraint_check = lower_cholesky.check(transformed_vector)
+    assert isinstance(transformed_vector, torch.Tensor), "Output is not a torch.Tensor"
+    assert (
+        transformed_vector.shape == expected_output_shape
+    ), f"Expected shape {expected_output_shape}, got {transformed_vector.shape}"
+    assert cholesky_constraint_check.all()


-@pytest.mark.parametrize("input_vector", [torch.randn(3), torch.randn(6)])
-def test_forward_equals_inverse(input_vector):
+
+@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
+def test_forward_equals_inverse(batch_shape, d):
     transform = FillScaleTriL()
+    input_shape = batch_shape + (d,)
+    input_vector = torch.randn(input_shape)
     L = transform._call(input_vector)
     invL = transform._inverse(L)
-    assert torch.allclose(input_vector, invL, atol=1e-6)
+    n = int((-1 + torch.sqrt(torch.tensor(1 + 8 * d))) / 2)
+
+    assert torch.allclose(
+        input_vector, invL, atol=1e-4
+    ), "Original input and the result of applying inverse transformation are not close enough"
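For reference, here is a minimal usage sketch of the transform after this change. It is not part of the diff: the batch shape, the vector length, and the `MultivariateNormal` wiring are illustrative assumptions; `FillScaleTriL`, its import path, and the `1e-06` default `diag_shift` come from the diff above.

```python
import torch
from torch.distributions import MultivariateNormal
from rs_distributions.transforms.fill_scale_tril import FillScaleTriL

# Hypothetical shapes: a batch of 2 unconstrained vectors of length d = 6,
# which fill n x n lower triangular matrices with n = 3 (d = n * (n + 1) / 2).
transform = FillScaleTriL()  # defaults to SoftplusTransform on the diagonal, diag_shift=1e-06
x = torch.randn(2, 6)

L = transform(x)             # lower triangular with positive diagonal, shape (2, 3, 3)
x_back = transform.inv(L)    # round-trips back to the unconstrained vector (up to numerical error)

# The factors can then parameterize, for example, a batch of multivariate normals.
mvn = MultivariateNormal(loc=torch.zeros(2, 3), scale_tril=L)
print(mvn.sample().shape)    # torch.Size([2, 3])
```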