Skip to content

Commit

Permalink
Added new transforms submodules and fill_scale_tril.py transform to c… (
Browse files Browse the repository at this point in the history
#5)

* Added new transforms submodules and fill_scale_tril.py transform to convert vectors into lower triangular matrices

* Tests now handle nested batches. FillScaleTriL uses .diagonal_scatter instead of .copy_

* Removed unused imports

* Wrote individual FillTriL, DiagTransform. Wrote composition FillScaleTriL

* Added test for `FillScaleTriL().log_abs_det_jacobian`

* Fixed documentation errors

* added `params_size` and default AffineTransform shift to `FillScaleTriL`

* Update fill_scale_tril.py

* Update fill_scale_tril.py

* Update fill_scale_tril.py

* Update fill_scale_tril.py

* Update fill_scale_tril.py

* Update fill_scale_tril.py
  • Loading branch information
LuisA92 authored May 28, 2024
1 parent 2a61d44 commit 55f9957
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/rs_distributions/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .fill_scale_tril import FillScaleTriL, FillTriL, DiagTransform

# Public names re-exported by the `transforms` subpackage.
__all__ = [
"FillScaleTriL",
"FillTriL",
"DiagTransform",
]
128 changes: 128 additions & 0 deletions src/rs_distributions/transforms/fill_scale_tril.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import torch
from torch.distributions import Transform, ComposeTransform, constraints
from torch.distributions.transforms import SoftplusTransform, AffineTransform
from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec


class FillTriL(Transform):
    """
    Rearranges a real-valued vector into a lower triangular matrix.

    The forward direction delegates to ``vec_to_tril_matrix`` and the inverse
    to ``tril_matrix_to_vec``; both operate on the trailing dimension(s) and
    leave any leading batch dimensions in place.
    """

    def __init__(self):
        super().__init__()

    @property
    def domain(self):
        # Inputs are flat real vectors.
        return constraints.real_vector

    @property
    def codomain(self):
        # Outputs are lower triangular matrices.
        return constraints.lower_triangular

    @property
    def bijective(self):
        return True

    def _call(self, x):
        """
        Pack a vector into a lower triangular matrix.

        Args:
            x (torch.Tensor): real-valued vector(s), shape ``(..., d)``
        Returns:
            torch.Tensor: lower triangular matrix built from ``x``
        """
        return vec_to_tril_matrix(x)

    def _inverse(self, y):
        """
        Flatten the lower triangle of ``y`` back into a vector.

        Args:
            y (torch.Tensor): lower triangular matrix
        Returns:
            torch.Tensor: vector holding the lower-triangular entries of ``y``
        """
        return tril_matrix_to_vec(y)

    def log_abs_det_jacobian(self, x, y):
        """
        Return the log-abs-det Jacobian, which is identically zero.

        Args:
            x (torch.Tensor): input vector(s), shape ``(..., d)``
            y (torch.Tensor): output of the forward call (unused)
        Returns:
            torch.Tensor: zeros shaped like the batch dimensions of ``x``,
            with the dtype and device of ``x``
        """
        # Drop the trailing vector dimension; what remains is the batch shape.
        return x.new_zeros(x.shape[:-1])


class DiagTransform(Transform):
    """
    Applies a transformation to the diagonal of a square matrix.

    Only the main diagonal of the last two dimensions is passed through
    ``diag_transform``; every off-diagonal entry is copied unchanged.

    Args:
        diag_transform (torch.distributions.Transform): transform applied to
            the diagonal entries
    """

    def __init__(self, diag_transform):
        super().__init__()
        self.diag_transform = diag_transform

    @property
    def domain(self):
        # Delegated to the wrapped transform.
        return self.diag_transform.domain

    @property
    def codomain(self):
        # Delegated to the wrapped transform.
        return self.diag_transform.codomain

    @property
    def bijective(self):
        # Bijective exactly when the wrapped transform is.
        return self.diag_transform.bijective

    def _call(self, x):
        """
        Args:
            x (torch.Tensor): Input matrix
        Returns:
            torch.Tensor: Matrix equal to ``x`` except that its diagonal has
            been mapped through ``diag_transform``
        """
        diagonal = x.diagonal(dim1=-2, dim2=-1)
        transformed_diagonal = self.diag_transform(diagonal)
        # diagonal_scatter returns a new tensor, so ``x`` is not modified.
        return x.diagonal_scatter(transformed_diagonal, dim1=-2, dim2=-1)

    def _inverse(self, y):
        """
        Args:
            y (torch.Tensor): Transformed matrix
        Returns:
            torch.Tensor: Matrix equal to ``y`` except that its diagonal has
            been mapped through the inverse of ``diag_transform``
        """
        diagonal = y.diagonal(dim1=-2, dim2=-1)
        restored_diagonal = self.diag_transform.inv(diagonal)
        return y.diagonal_scatter(restored_diagonal, dim1=-2, dim2=-1)

    def log_abs_det_jacobian(self, x, y):
        """
        Log-abs-det Jacobian of the wrapped transform on the diagonal.

        Args:
            x (torch.Tensor): Input matrix
            y (torch.Tensor): Output matrix, i.e. the result of calling this
                transform on ``x``
        Returns:
            torch.Tensor: whatever ``diag_transform.log_abs_det_jacobian``
            returns for the two diagonals; no further reduction is applied
        """
        x_diagonal = x.diagonal(dim1=-2, dim2=-1)
        # The wrapped transform must see matching (input, output) pairs, so
        # hand it the diagonal of ``y`` rather than the whole matrix. Some
        # transforms (e.g. PowerTransform) read ``y`` in their Jacobian.
        y_diagonal = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(x_diagonal, y_diagonal)


class FillScaleTriL(ComposeTransform):
    """
    A `ComposeTransform` that reshapes a real-valued vector into a lower triangular matrix.
    The diagonal of the matrix is transformed with `diag_transform`.

    Args:
        diag_transform (torch.distributions.Transform, optional): transform
            applied to the diagonal of the filled matrix. Defaults to a
            softplus followed by an affine shift of ``1e-5``.
    """

    def __init__(self, diag_transform=None):
        if diag_transform is None:
            # Default: softplus, then AffineTransform(loc=1e-5, scale=1.0),
            # i.e. the softplus output shifted by 1e-5.
            diag_transform = ComposeTransform(
                [
                    SoftplusTransform(),
                    AffineTransform(1e-5, 1.0),
                ]
            )
        super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])
        self.diag_transform = diag_transform

    @property
    def bijective(self):
        return True

    def log_abs_det_jacobian(self, x, y):
        """
        Log-abs-det Jacobian of the diagonal transform.

        Args:
            x (torch.Tensor): input vector(s), shape ``(..., d)``
            y (torch.Tensor): output matrix, i.e. the result of calling this
                transform on ``x``
        Returns:
            torch.Tensor: whatever ``diag_transform.log_abs_det_jacobian``
            returns for the pre- and post-transform diagonals; no further
            reduction is applied
        """
        # Filling the triangle only rearranges entries, so the Jacobian comes
        # entirely from the diagonal transform.
        x_diagonal = vec_to_tril_matrix(x).diagonal(dim1=-2, dim2=-1)
        # Pair the input diagonal with the *output* diagonal; transforms
        # whose Jacobian reads ``y`` would otherwise get the wrong value.
        y_diagonal = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(x_diagonal, y_diagonal)

    @staticmethod
    def params_size(event_size):
        """
        Returns the number of parameters required to create an n-by-n lower triangular matrix, which is given by n*(n+1)//2
        Args:
            event_size (int): size of event
        Returns:
            int: Number of parameters needed
        """
        return event_size * (event_size + 1) // 2
73 changes: 73 additions & 0 deletions tests/transforms/fill_scale_tril.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest
from rs_distributions.transforms.fill_scale_tril import (
FillScaleTriL,
FillTriL,
DiagTransform,
)
import torch
from torch.distributions.constraints import lower_cholesky
from torch.distributions.transforms import SoftplusTransform, ExpTransform


@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
def test_forward_transform(batch_shape, d):
    """The forward pass maps (..., d) vectors to (..., n, n) Cholesky factors."""
    fill_scale_tril = FillScaleTriL()
    transformed_vector = fill_scale_tril(torch.randn(batch_shape + (d,)))

    # Invert d = n * (n + 1) / 2 to recover the matrix side length n.
    n = int((torch.sqrt(torch.tensor(1 + 8 * d)) - 1) / 2)
    expected_output_shape = batch_shape + (n, n)
    satisfies_cholesky = lower_cholesky.check(transformed_vector)

    assert isinstance(transformed_vector, torch.Tensor), "Output is not a torch.Tensor"
    assert (
        transformed_vector.shape == expected_output_shape
    ), f"Expected shape {expected_output_shape}, got {transformed_vector.shape}"
    assert satisfies_cholesky.all()


@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
def test_forward_equals_inverse(batch_shape, d):
    """Applying the transform and then its inverse recovers the input vector."""
    fill_scale_tril = FillScaleTriL()
    original = torch.randn(batch_shape + (d,))

    # Round trip: vector -> lower triangular matrix -> vector.
    recovered = fill_scale_tril.inv(fill_scale_tril(original))

    assert torch.allclose(
        original, recovered, atol=1e-4
    ), "Original input and the result of applying inverse transformation are not close enough"


@pytest.mark.parametrize(
    "batch_shape, d, diag_transform",
    [
        ((2, 3), 6, SoftplusTransform()),
        ((1, 4, 5), 10, SoftplusTransform()),
        ((2, 3), 6, ExpTransform()),
        ((1, 4, 5), 10, ExpTransform()),
    ],
)
def test_log_abs_det_jacobian_softplus_and_exp(batch_shape, d, diag_transform):
    """FillScaleTriL's log-Jacobian matches that of its DiagTransform stage."""
    composed = FillScaleTriL(diag_transform=diag_transform)
    fill = FillTriL()
    diag_only = DiagTransform(diag_transform=diag_transform)

    vec = torch.randn(batch_shape + (d,), requires_grad=True)

    # Log-abs-det Jacobian as reported by the composed transform.
    composed_ladj = composed.log_abs_det_jacobian(vec, composed(vec))

    # Reference value: run the two stages by hand and query the diagonal
    # stage for its log-abs-det Jacobian.
    tril = fill(vec)
    reference_ladj = diag_only.log_abs_det_jacobian(tril, diag_only(tril))

    # Both routes must agree up to numerical tolerance.
    assert torch.allclose(reference_ladj, composed_ladj, atol=1e-4)

0 comments on commit 55f9957

Please sign in to comment.