Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new transforms submodules and fill_scale_tril.py transform to c… #5

Merged
merged 13 commits into from
May 28, 2024
Merged
7 changes: 7 additions & 0 deletions src/rs_distributions/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .fill_scale_tril import FillScaleTriL, FillTriL, DiagTransform

# Names exported by this package when a caller uses `from ... import *`.
__all__ = [
"FillScaleTriL",
"FillTriL",
"DiagTransform",
]
127 changes: 127 additions & 0 deletions src/rs_distributions/transforms/fill_scale_tril.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import torch
from torch.distributions import Transform, ComposeTransform, constraints
from torch.distributions.transforms import SoftplusTransform, AffineTransform
from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec
kmdalton marked this conversation as resolved.
Show resolved Hide resolved


class FillTriL(Transform):
    """
    Bijection between a real-valued vector and a lower triangular matrix.

    The forward direction packs the vector entries into the lower triangle
    (diagonal included) of a square matrix; the inverse unpacks them again.
    """

    def __init__(self):
        super().__init__()

    @property
    def bijective(self):
        return True

    @property
    def domain(self):
        return constraints.real_vector

    @property
    def codomain(self):
        return constraints.lower_triangular

    def _call(self, x):
        """
        Pack a real-valued vector into a lower triangular matrix.

        Args:
            x (torch.Tensor): input real-valued vector
        Returns:
            torch.Tensor: Lower triangular matrix
        """
        tril = vec_to_tril_matrix(x)
        return tril

    def _inverse(self, y):
        """
        Unpack the lower triangle of a matrix back into a vector.

        Args:
            y (torch.Tensor): lower triangular matrix
        Returns:
            torch.Tensor: vector of the lower-triangle entries
        """
        vec = tril_matrix_to_vec(y)
        return vec

    def log_abs_det_jacobian(self, x, y):
        """
        Return zeros with the shape of `x` minus its last dimension.

        Args:
            x (torch.Tensor): input vector; supplies shape, dtype and device
            y (torch.Tensor): transformed matrix (unused)
        Returns:
            torch.Tensor: zeros of shape `x.shape[:-1]`
        """
        # new_zeros inherits dtype and device from x.
        return x.new_zeros(x.shape[:-1])


class DiagTransform(Transform):
    """
    Applies a transformation to the diagonal of a square matrix.

    Only the main diagonal taken over the last two dimensions is mapped
    through `diag_transform`; every off-diagonal entry is passed through
    unchanged.

    Args:
        diag_transform (torch.distributions.Transform): transform applied to
            the diagonal entries.
    """

    def __init__(self, diag_transform):
        super().__init__()
        self.diag_transform = diag_transform

    # NOTE(review): domain and codomain are forwarded from the diagonal
    # transform even though this transform receives whole matrices; confirm
    # that this is the intended constraint.
    @property
    def domain(self):
        return self.diag_transform.domain

    @property
    def codomain(self):
        return self.diag_transform.codomain

    @property
    def bijective(self):
        return self.diag_transform.bijective

    def _call(self, x):
        """
        Args:
            x (torch.Tensor): Input matrix
        Returns:
            torch.Tensor: Transformed matrix
        """
        diagonal = x.diagonal(dim1=-2, dim2=-1)
        transformed_diagonal = self.diag_transform(diagonal)
        # diagonal_scatter returns a new tensor, so x itself is not modified.
        return x.diagonal_scatter(transformed_diagonal, dim1=-2, dim2=-1)

    def _inverse(self, y):
        """
        Args:
            y (torch.Tensor): Transformed matrix
        Returns:
            torch.Tensor: Matrix with the diagonal transform undone
        """
        diagonal = y.diagonal(dim1=-2, dim2=-1)
        restored_diagonal = self.diag_transform.inv(diagonal)
        return y.diagonal_scatter(restored_diagonal, dim1=-2, dim2=-1)

    def log_abs_det_jacobian(self, x, y):
        """
        Evaluate the diagonal transform's log-abs-det-Jacobian on the diagonals.

        Args:
            x (torch.Tensor): Input matrix
            y (torch.Tensor): Transformed matrix
        Returns:
            torch.Tensor: whatever `diag_transform.log_abs_det_jacobian`
                returns for the two diagonals (not reduced any further here)
        """
        x_diagonal = x.diagonal(dim1=-2, dim2=-1)
        # The inner transform must see the diagonal of y, not the whole
        # matrix: transforms whose Jacobian reads y would otherwise combine
        # tensors of mismatched shape and return wrong values.
        y_diagonal = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(x_diagonal, y_diagonal)


class FillScaleTriL(ComposeTransform):
    """
    A `ComposeTransform` that reshapes a real-valued vector into a lower triangular matrix.
    The diagonal of the matrix is transformed with `diag_transform`.

    Args:
        diag_transform (torch.distributions.Transform, optional): transform
            applied to the diagonal of the filled matrix. When omitted, a
            softplus followed by an additive shift of 1e-5 is used.
    """

    def __init__(self, diag_transform=None):
        if diag_transform is None:
            diag_transform = torch.distributions.ComposeTransform(
                (
                    SoftplusTransform(),
                    AffineTransform(1e-5, 1.0),
                )
            )
        super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])
        self.diag_transform = diag_transform

    @property
    def bijective(self):
        return True

    def log_abs_det_jacobian(self, x, y):
        """
        Evaluate the diagonal transform's log-abs-det-Jacobian for this mapping.

        Args:
            x (torch.Tensor): input real-valued vector
            y (torch.Tensor): transformed lower triangular matrix
        Returns:
            torch.Tensor: whatever `diag_transform.log_abs_det_jacobian`
                returns for the diagonals before and after the transform
                (not reduced any further here)
        """
        # Diagonal before the diagonal transform, recovered by filling x.
        x_diagonal = vec_to_tril_matrix(x).diagonal(dim1=-2, dim2=-1)
        # Diagonal after the transform must come from y; passing the
        # untransformed diagonal twice is wrong for any transform whose
        # Jacobian reads its output.
        y_diagonal = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(x_diagonal, y_diagonal)

    @staticmethod
    def params_size(event_size):
        """
        Returns the number of parameters required to create an n-by-n lower
        triangular matrix, which is given by n*(n+1)//2

        Args:
            event_size (int): size of event

        Returns:
            int: Number of parameters needed

        """
        return event_size * (event_size + 1) // 2
73 changes: 73 additions & 0 deletions tests/transforms/fill_scale_tril.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest
from rs_distributions.transforms.fill_scale_tril import (
FillScaleTriL,
FillTriL,
DiagTransform,
)
import torch
from torch.distributions.constraints import lower_cholesky
from torch.distributions.transforms import SoftplusTransform, ExpTransform


@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
def test_forward_transform(batch_shape, d):
    """Forward pass returns a tensor of n-by-n lower Cholesky factors per batch entry."""
    fill_scale_tril = FillScaleTriL()
    input_vector = torch.randn(batch_shape + (d,))
    transformed_vector = fill_scale_tril(input_vector)

    # Recover the matrix side length n from d = n * (n + 1) / 2.
    discriminant = torch.tensor(1 + 8 * d)
    n = int((-1 + torch.sqrt(discriminant)) / 2)
    expected_output_shape = batch_shape + (n, n)
    cholesky_constraint_check = lower_cholesky.check(transformed_vector)

    assert isinstance(transformed_vector, torch.Tensor), "Output is not a torch.Tensor"
    assert (
        transformed_vector.shape == expected_output_shape
    ), f"Expected shape {expected_output_shape}, got {transformed_vector.shape}"
    assert cholesky_constraint_check.all()


@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
def test_forward_equals_inverse(batch_shape, d):
    """Round trip: the inverse applied to the forward output recovers the input."""
    fill_scale_tril = FillScaleTriL()
    original = torch.randn(batch_shape + (d,))

    # Forward to a matrix, then back to a vector.
    recovered = fill_scale_tril.inv(fill_scale_tril(original))

    round_trip_ok = torch.allclose(original, recovered, atol=1e-4)
    assert (
        round_trip_ok
    ), "Original input and the result of applying inverse transformation are not close enough"


@pytest.mark.parametrize(
    "batch_shape, d, diag_transform",
    [
        ((2, 3), 6, SoftplusTransform()),
        ((1, 4, 5), 10, SoftplusTransform()),
        ((2, 3), 6, ExpTransform()),
        ((1, 4, 5), 10, ExpTransform()),
    ],
)
def test_log_abs_det_jacobian_softplus_and_exp(batch_shape, d, diag_transform):
    """FillScaleTriL's log-abs-det-Jacobian matches running its two stages by hand."""
    composed = FillScaleTriL(diag_transform=diag_transform)
    fill_stage = FillTriL()
    diag_stage = DiagTransform(diag_transform=diag_transform)

    vec = torch.randn(batch_shape + (d,), requires_grad=True)

    # Value reported by the composed transform.
    composed_ladj = composed.log_abs_det_jacobian(vec, composed(vec))

    # Same quantity from the stages applied one after the other.
    tril = fill_stage(vec)
    staged_ladj = diag_stage.log_abs_det_jacobian(tril, diag_stage(tril))

    assert torch.allclose(staged_ladj, composed_ladj, atol=1e-4)
Loading