"""Transforms that map an unconstrained vector to a lower-triangular scale matrix.

Reconstructed from a mangled patch. ``FillTriL`` packs a vector into a
lower-triangular matrix, ``DiagTransform`` applies an elementwise transform to
the diagonal, and ``FillScaleTriL`` composes the two so the result has a
positive diagonal (suitable as a Cholesky factor).
"""
import torch
from torch.distributions import Transform, ComposeTransform, constraints
from torch.distributions.transforms import SoftplusTransform
from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec


class FillTriL(Transform):
    """Bijection between a flat vector of length n(n+1)/2 and an n x n
    lower-triangular matrix (strictly a rearrangement — volume preserving)."""

    def __init__(self):
        super().__init__()

    @property
    def domain(self):
        # Input is an unconstrained flat vector.
        return constraints.real_vector

    @property
    def codomain(self):
        return constraints.lower_triangular

    @property
    def bijective(self):
        return True

    def _call(self, x):
        """Pack the trailing dimension of ``x`` into a lower-triangular matrix.

        Args:
            x (torch.Tensor): shape ``(..., n(n+1)/2)``.

        Returns:
            torch.Tensor: shape ``(..., n, n)`` lower-triangular matrix.
        """
        return vec_to_tril_matrix(x)

    def _inverse(self, y):
        """Unpack a lower-triangular matrix back into a flat vector.

        Args:
            y (torch.Tensor): shape ``(..., n, n)`` lower-triangular matrix.

        Returns:
            torch.Tensor: shape ``(..., n(n+1)/2)`` vector.
        """
        return tril_matrix_to_vec(y)

    def log_abs_det_jacobian(self, x, y):
        # Pure rearrangement of entries: |det J| == 1, so the log-det is 0.
        # Fixed: return one zero per *batch element* (x.shape[:-1]); the
        # original used x.shape[0], which is wrong for unbatched input
        # (gave shape (d,)) and for multi-dimensional batch shapes.
        return torch.zeros(x.shape[:-1], dtype=x.dtype, device=x.device)


class DiagTransform(Transform):
    """Apply an elementwise transform to the diagonal of a square matrix,
    leaving all off-diagonal entries untouched."""

    def __init__(self, diag_transform):
        """
        Args:
            diag_transform (Transform): elementwise bijection applied to the
                diagonal (e.g. ``SoftplusTransform`` to force positivity).
        """
        super().__init__()
        self.diag_transform = diag_transform

    @property
    def domain(self):
        # Delegates to the wrapped transform; only diagonal entries are
        # actually constrained.
        return self.diag_transform.domain

    @property
    def codomain(self):
        return self.diag_transform.codomain

    @property
    def bijective(self):
        return self.diag_transform.bijective

    def _call(self, x):
        """Transform the diagonal of ``x``.

        Args:
            x (torch.Tensor): shape ``(..., n, n)``.

        Returns:
            torch.Tensor: same shape, diagonal replaced by
            ``diag_transform(diag(x))``.
        """
        diagonal = x.diagonal(dim1=-2, dim2=-1)
        transformed = self.diag_transform(diagonal)
        return x.diagonal_scatter(transformed, dim1=-2, dim2=-1)

    def _inverse(self, y):
        """Invert the diagonal transform, leaving off-diagonals untouched."""
        diagonal = y.diagonal(dim1=-2, dim2=-1)
        return y.diagonal_scatter(self.diag_transform.inv(diagonal), dim1=-2, dim2=-1)

    def log_abs_det_jacobian(self, x, y):
        """Log |det J| of the diagonal transform, elementwise over the diagonal.

        Off-diagonal entries are passed through unchanged and contribute 0.
        Fixed: the inner transform must be given the *diagonal* of ``y``,
        not the full matrix (the original only worked by accident because
        Softplus's ladj ignores its second argument).
        """
        x_diag = x.diagonal(dim1=-2, dim2=-1)
        y_diag = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(x_diag, y_diag)


class FillScaleTriL(ComposeTransform):
    """Map an unconstrained vector to a lower-triangular matrix whose
    diagonal has been passed through ``diag_transform`` (positive by default)."""

    def __init__(self, diag_transform=None):
        """
        Args:
            diag_transform (Transform, optional): transform applied to the
                diagonal. Defaults to ``SoftplusTransform()``. (Default is
                ``None`` rather than a shared instance to avoid the
                mutable-default-argument pitfall.)
        """
        if diag_transform is None:
            diag_transform = SoftplusTransform()
        super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])

    @property
    def bijective(self):
        return True