Tests now handle nested batches. FillScaleTriL uses .diagonal_scatter instead of .copy_
Showing 2 changed files with 67 additions and 43 deletions.
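The switch away from in-place .copy_ matters for batched inputs: torch.Tensor.diagonal_scatter returns a new tensor with the given values written along a diagonal, so it composes cleanly with autograd and arbitrary leading batch dimensions. A minimal sketch of the idea (not the library's actual implementation; the softplus choice here is illustrative):

    import torch

    # Build a batch of lower-triangular matrices: shape (2, 3, 4, 4)
    L = torch.randn(2, 3, 4, 4).tril()

    # Transform the diagonal (softplus keeps it positive, as a Cholesky
    # factor requires); result has shape (2, 3, 4)
    new_diag = torch.nn.functional.softplus(L.diagonal(dim1=-2, dim2=-1))

    # Out-of-place: returns a fresh tensor with new_diag written along
    # the main diagonal of the last two dims, preserving the batch dims
    L_out = L.diagonal_scatter(new_diag, dim1=-2, dim2=-1)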
@@ -1,27 +1,43 @@
 import pytest
 from rs_distributions.transforms.fill_scale_tril import FillScaleTriL
 import torch
 from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec
 from torch.distributions.constraints import lower_cholesky
 from torch.distributions.transforms import (
     ComposeTransform,
     ExpTransform,
     SoftplusTransform,
 )


-@pytest.mark.parametrize("input_shape", [(6,), (10,)])
-def test_forward_transform(input_shape):
+@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
+def test_forward_transform(batch_shape, d):
     transform = FillScaleTriL()
+    input_shape = batch_shape + (d,)
     input_vector = torch.randn(input_shape)
     transformed_vector = transform._call(input_vector)

-    assert isinstance(transformed_vector, torch.Tensor)
-    assert transformed_vector.shape == (
-        (-1 + torch.sqrt(torch.tensor(1 + input_shape[0] * 8))) / 2,
-        (-1 + torch.sqrt(torch.tensor(1 + input_shape[0] * 8))) / 2,
-    )
-    assert lower_cholesky.check(transformed_vector)
+    n = int((-1 + torch.sqrt(torch.tensor(1 + 8 * d))) / 2)
+    expected_output_shape = batch_shape + (n, n)
+    cholesky_constraint_check = lower_cholesky.check(transformed_vector)
+
+    assert isinstance(transformed_vector, torch.Tensor), "Output is not a torch.Tensor"
+    assert (
+        transformed_vector.shape == expected_output_shape
+    ), f"Expected shape {expected_output_shape}, got {transformed_vector.shape}"
+    assert cholesky_constraint_check.all()


-@pytest.mark.parametrize("input_vector", [torch.randn(3), torch.randn(6)])
-def test_forward_equals_inverse(input_vector):
+@pytest.mark.parametrize("batch_shape, d", [((2, 3), 6), ((1, 4, 5), 10)])
+def test_forward_equals_inverse(batch_shape, d):
     transform = FillScaleTriL()
+    input_shape = batch_shape + (d,)
+    input_vector = torch.randn(input_shape)
     L = transform._call(input_vector)
     invL = transform._inverse(L)
+    n = int((-1 + torch.sqrt(torch.tensor(1 + 8 * d))) / 2)

-    assert torch.allclose(input_vector, invL, atol=1e-6)
+    assert torch.allclose(
+        input_vector, invL, atol=1e-4
+    ), "Original input and the result of applying inverse transformation are not close enough"