-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added new transforms submodules and fill_scale_tril.py transform to c… (
#5) * Added new transforms submodules and fill_scale_tril.py transform to convert vectors into lower triangular matrices * Tests now handle nested batches. FillScaleTriL uses .diagonal_scatter instead of .copy_ * Removed unused imports * Wrote individual FillTriL, DiagTransform. Wrote composition FillScaleTriL * Added test for `FillScaleTriL().log_abs_det_jacobian` * Fixed documentation errors * added `params_size` and default AffineTransform shift to `FillScaleTriL` * Update fill_scale_tril.py * Update fill_scale_tril.py * Update fill_scale_tril.py * Update fill_scale_tril.py * Update fill_scale_tril.py * Update fill_scale_tril.py
- Loading branch information
Showing
3 changed files
with
208 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .fill_scale_tril import FillScaleTriL, FillTriL, DiagTransform | ||
|
||
__all__ = [ | ||
"FillScaleTriL", | ||
"FillTriL", | ||
"DiagTransform", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import torch | ||
from torch.distributions import Transform, ComposeTransform, constraints | ||
from torch.distributions.transforms import SoftplusTransform, AffineTransform | ||
from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec | ||
|
||
|
||
class FillTriL(Transform):
    """
    Reshape a real-valued vector into a lower triangular matrix.

    The forward direction delegates to
    ``torch.distributions.utils.vec_to_tril_matrix`` and the inverse to
    ``tril_matrix_to_vec``.
    """

    def __init__(self):
        super().__init__()

    @property
    def bijective(self):
        return True

    @property
    def domain(self):
        return constraints.real_vector

    @property
    def codomain(self):
        return constraints.lower_triangular

    def _call(self, x):
        """
        Pack a vector into a lower triangular matrix.

        Args:
            x (torch.Tensor): input real-valued vector
        Returns:
            torch.Tensor: Lower triangular matrix
        """
        tril = vec_to_tril_matrix(x)
        return tril

    def _inverse(self, y):
        """
        Unpack a lower triangular matrix back into a vector.

        Args:
            y (torch.Tensor): lower triangular matrix
        Returns:
            torch.Tensor: vector of the lower triangular entries
        """
        vec = tril_matrix_to_vec(y)
        return vec

    def log_abs_det_jacobian(self, x, y):
        """
        Return zeros: one entry per batch element.

        Args:
            x (torch.Tensor): input vector; only its shape, dtype and device are used
            y (torch.Tensor): unused
        Returns:
            torch.Tensor: zeros of shape ``x.shape[:-1]`` with the dtype and device of ``x``
        """
        # new_zeros inherits dtype and device from x.
        return x.new_zeros(x.shape[:-1])
|
||
|
||
class DiagTransform(Transform):
    """
    Applies transformation to the diagonal of a square matrix.

    Only the main diagonal over the last two dimensions is passed through
    ``diag_transform``; all other entries are copied unchanged.

    Args:
        diag_transform (Transform): transform applied to the diagonal entries
    """

    def __init__(self, diag_transform):
        super().__init__()
        self.diag_transform = diag_transform

    # NOTE(review): domain/codomain are forwarded unchanged from the wrapped
    # transform, so they describe the diagonal entries rather than whole
    # matrices -- confirm this is what callers expect.
    @property
    def domain(self):
        return self.diag_transform.domain

    @property
    def codomain(self):
        return self.diag_transform.codomain

    @property
    def bijective(self):
        return self.diag_transform.bijective

    def _call(self, x):
        """
        Args:
            x (torch.Tensor): Input matrix
        Returns:
            torch.Tensor: Copy of ``x`` whose diagonal has been transformed
        """
        diagonal = x.diagonal(dim1=-2, dim2=-1)
        transformed_diagonal = self.diag_transform(diagonal)
        # diagonal_scatter is out-of-place, so ``x`` itself is not modified.
        return x.diagonal_scatter(transformed_diagonal, dim1=-2, dim2=-1)

    def _inverse(self, y):
        """
        Args:
            y (torch.Tensor): Matrix with a transformed diagonal
        Returns:
            torch.Tensor: Copy of ``y`` with ``diag_transform.inv`` applied to its diagonal
        """
        diagonal = y.diagonal(dim1=-2, dim2=-1)
        restored_diagonal = self.diag_transform.inv(diagonal)
        return y.diagonal_scatter(restored_diagonal, dim1=-2, dim2=-1)

    def log_abs_det_jacobian(self, x, y):
        """
        Evaluate the wrapped transform's log-abs-det-Jacobian on the diagonals.

        Args:
            x (torch.Tensor): Input matrix
            y (torch.Tensor): Output matrix, i.e. the result of applying this transform to ``x``
        Returns:
            torch.Tensor: Whatever ``diag_transform.log_abs_det_jacobian`` returns
            for the two diagonals; no reduction over the diagonal is done here.
        """
        x_diagonal = x.diagonal(dim1=-2, dim2=-1)
        # The wrapped transform must see the diagonal of ``y``, not the whole
        # matrix: transforms whose Jacobian reads ``y`` (e.g. PowerTransform)
        # would otherwise broadcast the matrix against the diagonal of ``x``.
        y_diagonal = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(x_diagonal, y_diagonal)
|
||
|
||
class FillScaleTriL(ComposeTransform):
    """
    A `ComposeTransform` that reshapes a real-valued vector into a lower triangular matrix.
    The diagonal of the matrix is transformed with `diag_transform`.

    Args:
        diag_transform (Transform, optional): transform applied to the diagonal.
            When ``None``, a softplus followed by an affine shift of ``1e-5``
            (scale ``1.0``) is used.
    """

    def __init__(self, diag_transform=None):
        if diag_transform is None:
            # Softplus output plus a small constant offset; presumably the
            # offset keeps the diagonal bounded away from zero.
            diag_transform = ComposeTransform(
                [
                    SoftplusTransform(),
                    AffineTransform(1e-5, 1.0),
                ]
            )
        super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])
        self.diag_transform = diag_transform

    @property
    def bijective(self):
        return True

    def log_abs_det_jacobian(self, x, y):
        """
        Evaluate ``diag_transform``'s log-abs-det-Jacobian on the matrix diagonal.

        Args:
            x (torch.Tensor): Input vector
            y (torch.Tensor): Output matrix produced by this transform from ``x``
        Returns:
            torch.Tensor: Result of ``diag_transform.log_abs_det_jacobian`` for
            the pre- and post-transform diagonals; not reduced over the diagonal here.
        """
        # The reshape step contributes nothing, so only the diagonal matters.
        x_diagonal = vec_to_tril_matrix(x).diagonal(dim1=-2, dim2=-1)
        # Use the diagonal of the real output; passing the untransformed
        # diagonal as ``y`` is wrong for transforms whose Jacobian reads ``y``.
        y_diagonal = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(x_diagonal, y_diagonal)

    @staticmethod
    def params_size(event_size):
        """
        Returns the number of parameters required to create an n-by-n lower triangular matrix, which is given by n*(n+1)//2

        Args:
            event_size (int): size of event (the ``n`` in the formula above)
        Returns:
            int: Number of parameters needed
        """
        return event_size * (event_size + 1) // 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import pytest | ||
from rs_distributions.transforms.fill_scale_tril import ( | ||
FillScaleTriL, | ||
FillTriL, | ||
DiagTransform, | ||
) | ||
import torch | ||
from torch.distributions.constraints import lower_cholesky | ||
from torch.distributions.transforms import SoftplusTransform, ExpTransform | ||
|
||
|
||
@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
def test_forward_transform(batch_shape, d):
    """Forward pass yields a tensor of shape batch + (n, n) that passes the lower_cholesky check."""
    vec = torch.randn(batch_shape + (d,))
    transformed_vector = FillScaleTriL()(vec)
    cholesky_ok = lower_cholesky.check(transformed_vector)

    # Recover the matrix side n from the vector length d = n * (n + 1) / 2.
    discriminant = torch.sqrt(torch.tensor(1 + 8 * d))
    n = int((-1 + discriminant) / 2)
    expected_output_shape = batch_shape + (n, n)

    assert isinstance(transformed_vector, torch.Tensor), "Output is not a torch.Tensor"
    assert (
        transformed_vector.shape == expected_output_shape
    ), f"Expected shape {expected_output_shape}, got {transformed_vector.shape}"
    assert cholesky_ok.all()
|
||
|
||
@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
def test_forward_equals_inverse(batch_shape, d):
    """Applying the transform and then its inverse recovers the input vector."""
    transform = FillScaleTriL()
    vec = torch.randn(batch_shape + (d,))

    # Round trip: vector -> matrix -> vector.
    round_trip = transform.inv(transform(vec))

    assert torch.allclose(
        vec, round_trip, atol=1e-4
    ), "Original input and the result of applying inverse transformation are not close enough"
|
||
|
||
@pytest.mark.parametrize(
    "batch_shape, d, diag_transform",
    [
        ((2, 3), 6, SoftplusTransform()),
        ((1, 4, 5), 10, SoftplusTransform()),
        ((2, 3), 6, ExpTransform()),
        ((1, 4, 5), 10, ExpTransform()),
    ],
)
def test_log_abs_det_jacobian_softplus_and_exp(batch_shape, d, diag_transform):
    """FillScaleTriL's log-abs-det-Jacobian matches FillTriL followed by DiagTransform."""
    vec = torch.randn(batch_shape + (d,), requires_grad=True)

    # Jacobian term reported by the composed transform.
    composed = FillScaleTriL(diag_transform=diag_transform)
    composed_ladj = composed.log_abs_det_jacobian(vec, composed(vec))

    # Same quantity computed by applying the two stages by hand.
    tril = FillTriL()(vec)
    diag_stage = DiagTransform(diag_transform=diag_transform)
    diag_ladj = diag_stage.log_abs_det_jacobian(tril, diag_stage(tril))

    assert torch.allclose(diag_ladj, composed_ladj, atol=1e-4)