diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/layers/__init__.py b/tests/layers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/layers/distributions/__init__.py b/tests/layers/distributions/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/layers/test_einsum_layer.py b/tests/layers/test_einsum_layer.py
new file mode 100644
index 0000000..118c6bb
--- /dev/null
+++ b/tests/layers/test_einsum_layer.py
@@ -0,0 +1,72 @@
+from unittest import TestCase
+
+import torch
+from parameterized import parameterized
+
+from simple_einet.abstract_layers import logits_to_log_weights
+from simple_einet.layers.einsum import EinsumLayer
+from simple_einet.sampling_utils import index_one_hot
+from tests.layers.test_utils import get_sampling_context
+
+
+class TestEinsumLayer(TestCase):
+    def setUp(self) -> None:
+        self.layer = EinsumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)
+
+    def test_logits_to_log_weights(self):
+        for dim in range(self.layer.logits.dim()):
+            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
+            sums = log_weights.logsumexp(dim=dim)
+            target = torch.zeros_like(sums)
+            self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    def test_forward_shape(self):
+        bs = 2
+        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        out = self.layer(x)
+        self.assertEqual(
+            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
+        )
+
+    @parameterized.expand([(False,), (True,)])
+    def test__sample_from_weights(self, differentiable: bool):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
+        log_weights = self.layer._select_weights(ctx, self.layer.logits)
+        indices = self.layer._sample_from_weights(ctx, log_weights)
+        if differentiable:
+            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
+        else:
+            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))
+
+    @parameterized.expand([(False,), (True,)])
+    def test__select_weights(self, differentiable: bool):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
+        weights = self.layer._select_weights(ctx, self.layer.logits)
+        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in**2))
+
+    @parameterized.expand([(False,), (True,)])
+    def test__condition_weights_on_evidence(self, differentiable: bool):
+        bs = 2
+        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        self.layer._enable_input_cache()
+        self.layer(x)
+
+        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
+        log_weights = self.layer._select_weights(ctx, self.layer.logits)
+        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
+        sums = log_weights.logsumexp(dim=2)
+        target = torch.zeros_like(sums)
+        self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    def test__differentiable_sampling_has_grads(self):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
+        ctx = self.layer.sample(ctx)
+
+        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
+        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
+        sample.mean().backward()
+        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/tests/layers/test_linsum_layer.py b/tests/layers/test_linsum_layer.py
new file mode 100644
index 0000000..cbc7083
--- /dev/null
+++ b/tests/layers/test_linsum_layer.py
@@ -0,0 +1,72 @@
+from unittest import TestCase
+
+import torch
+from parameterized import parameterized
+
+from simple_einet.abstract_layers import logits_to_log_weights
+from simple_einet.layers.linsum import LinsumLayer
+from simple_einet.sampling_utils import index_one_hot
+from tests.layers.test_utils import get_sampling_context
+
+
+class TestLinsumLayer(TestCase):
+    def setUp(self) -> None:
+        self.layer = LinsumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)
+
+    def test_logits_to_log_weights(self):
+        for dim in range(self.layer.logits.dim()):
+            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
+            sums = log_weights.logsumexp(dim=dim)
+            target = torch.zeros_like(sums)
+            self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    def test_forward_shape(self):
+        bs = 2
+        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        out = self.layer(x)
+        self.assertEqual(
+            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
+        )
+
+    @parameterized.expand([(False,), (True,)])
+    def test__condition_weights_on_evidence(self, differentiable: bool):
+        bs = 2
+        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        self.layer._enable_input_cache()
+        self.layer(x)
+
+        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
+        log_weights = self.layer._select_weights(ctx, self.layer.logits)
+        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
+        sums = log_weights.logsumexp(dim=2)
+        target = torch.zeros_like(sums)
+        self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    @parameterized.expand([(False,), (True,)])
+    def test__sample_from_weights(self, differentiable: bool):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
+        log_weights = self.layer._select_weights(ctx, self.layer.logits)
+        indices = self.layer._sample_from_weights(ctx, log_weights)
+        if differentiable:
+            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
+        else:
+            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))
+
+    @parameterized.expand([(False,), (True,)])
+    def test__select_weights(self, differentiable: bool):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
+        weights = self.layer._select_weights(ctx, self.layer.logits)
+        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))
+
+    def test__differentiable_sampling_has_grads(self):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
+        ctx = self.layer.sample(ctx)
+
+        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
+        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
+        sample.mean().backward()
+        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/tests/layers/test_mixing_layer.py b/tests/layers/test_mixing_layer.py
new file mode 100644
index 0000000..c88e0ce
--- /dev/null
+++ b/tests/layers/test_mixing_layer.py
@@ -0,0 +1,70 @@
+from unittest import TestCase
+
+import torch
+from parameterized import parameterized
+
+from simple_einet.abstract_layers import logits_to_log_weights
+from simple_einet.layers.mixing import MixingLayer
+from simple_einet.sampling_utils import index_one_hot
+from tests.layers.test_utils import get_sampling_context
+
+
+class TestMixingLayer(TestCase):
+    def setUp(self) -> None:
+        self.layer = MixingLayer(num_features=1, num_sums_in=3, num_sums_out=2)
+
+    def test_logits_to_log_weights(self):
+        for dim in range(self.layer.logits.dim()):
+            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
+            sums = log_weights.logsumexp(dim=dim)
+            target = torch.zeros_like(sums)
+            self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    def test_forward_shape(self):
+        bs = 2
+        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_out, self.layer.num_sums_in)
+        out = self.layer(x)
+        self.assertEqual(out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out))
+
+    @parameterized.expand([(False,), (True,)])
+    def test__condition_weights_on_evidence(self, differentiable: bool):
+        bs = 2
+        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_out, self.layer.num_sums_in)
+        self.layer._enable_input_cache()
+        self.layer(x)
+
+        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
+        log_weights = self.layer._select_weights(ctx, self.layer.logits)
+        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
+        sums = log_weights.logsumexp(dim=2)
+        target = torch.zeros_like(sums)
+        self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    @parameterized.expand([(False,), (True,)])
+    def test__sample_from_weights(self, differentiable: bool):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
+        log_weights = self.layer._select_weights(ctx, self.layer.logits)
+        indices = self.layer._sample_from_weights(ctx, log_weights)
+        if differentiable:
+            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
+        else:
+            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))
+
+    @parameterized.expand([(False,), (True,)])
+    def test__select_weights(self, differentiable: bool):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
+        weights = self.layer._select_weights(ctx, self.layer.logits)
+        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))
+
+    def test__differentiable_sampling_has_grads(self):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
+        ctx = self.layer.sample(ctx)
+
+        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
+        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
+        sample.mean().backward()
+        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/tests/layers/test_sum_layer.py b/tests/layers/test_sum_layer.py
new file mode 100644
index 0000000..9e4631b
--- /dev/null
+++ b/tests/layers/test_sum_layer.py
@@ -0,0 +1,72 @@
+from unittest import TestCase
+
+import torch
+from parameterized import parameterized
+
+from simple_einet.abstract_layers import logits_to_log_weights
+from simple_einet.layers.sum import SumLayer
+from simple_einet.sampling_utils import index_one_hot
+from tests.layers.test_utils import get_sampling_context
+
+
+class TestSumLayer(TestCase):
+    def setUp(self) -> None:
+        self.layer = SumLayer(num_features=4, num_sums_in=3, num_sums_out=2, num_repetitions=5)
+
+    def test_logits_to_log_weights(self):
+        for dim in range(self.layer.logits.dim()):
+            log_weights = logits_to_log_weights(self.layer.logits, dim=dim)
+            sums = log_weights.logsumexp(dim=dim)
+            target = torch.zeros_like(sums)
+            self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    def test_forward_shape(self):
+        bs = 2
+        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        out = self.layer(x)
+        self.assertEqual(
+            out.shape, (bs, self.layer.num_features_out, self.layer.num_sums_out, self.layer.num_repetitions)
+        )
+
+    @parameterized.expand([(False,), (True,)])
+    def test__sample_from_weights(self, differentiable: bool):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
+        log_weights = self.layer._select_weights(ctx, self.layer.logits)
+        indices = self.layer._sample_from_weights(ctx, log_weights)
+        if differentiable:
+            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features, self.layer.num_sums_in))
+        else:
+            self.assertEqual(tuple(indices.shape), (N, self.layer.num_features))
+
+    @parameterized.expand([(False,), (True,)])
+    def test__select_weights(self, differentiable: bool):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
+        weights = self.layer._select_weights(ctx, self.layer.logits)
+        self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))
+
+    @parameterized.expand([(False,), (True,)])
+    def test__condition_weights_on_evidence(self, differentiable: bool):
+        bs = 2
+        x = torch.randn(bs, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        self.layer._enable_input_cache()
+        self.layer(x)
+
+        ctx = get_sampling_context(layer=self.layer, num_samples=bs, is_differentiable=differentiable)
+        log_weights = self.layer._select_weights(ctx, self.layer.logits)
+        log_weights = self.layer._condition_weights_on_evidence(ctx, log_weights)
+        sums = log_weights.logsumexp(dim=2)
+        target = torch.zeros_like(sums)
+        self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    def test__differentiable_sampling_has_grads(self):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
+        ctx = self.layer.sample(ctx)
+
+        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
+        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
+        sample.mean().backward()
+        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/tests/layers/test_utils.py b/tests/layers/test_utils.py
new file mode 100644
index 0000000..3feecd0
--- /dev/null
+++ b/tests/layers/test_utils.py
@@ -0,0 +1,27 @@
+import torch
+from torch.nn import functional as F
+
+from simple_einet.sampling_utils import SamplingContext
+
+
+def get_sampling_context(layer, num_samples: int, is_differentiable: bool = False):
+    if is_differentiable:
+        indices_out = torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out))
+        one_hot_indices_out = F.one_hot(indices_out, num_classes=layer.num_sums_out).float()
+        indices_repetition = torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,))
+        one_hot_indices_repetition = F.one_hot(indices_repetition, num_classes=layer.num_repetitions).float()
+        one_hot_indices_out.requires_grad_(True)
+        one_hot_indices_repetition.requires_grad_(True)
+        return SamplingContext(
+            num_samples=num_samples,
+            indices_out=one_hot_indices_out,
+            indices_repetition=one_hot_indices_repetition,
+            is_differentiable=True,
+        )
+    else:
+        return SamplingContext(
+            num_samples=num_samples,
+            indices_out=torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out)),
+            indices_repetition=torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,)),
+            is_differentiable=False,
+        )
diff --git a/tests/test_einet.py b/tests/test_einet.py
new file mode 100644
index 0000000..abfacd8
--- /dev/null
+++ b/tests/test_einet.py
@@ -0,0 +1,66 @@
+from unittest import TestCase
+
+from itertools import product
+from simple_einet.einet import Einet, EinetConfig
+import torch
+from parameterized import parameterized
+
+from simple_einet.abstract_layers import logits_to_log_weights
+from simple_einet.layers.distributions.binomial import Binomial
+from simple_einet.layers.linsum import LinsumLayer
+from simple_einet.sampling_utils import index_one_hot
+
+
+class TestEinet(TestCase):
+    def make_einet(self, num_classes, num_repetitions):
+        config = EinetConfig(
+            num_features=self.num_features,
+            num_channels=self.num_channels,
+            depth=self.depth,
+            num_sums=self.num_sums,
+            num_leaves=self.num_leaves,
+            num_repetitions=num_repetitions,
+            num_classes=num_classes,
+            leaf_type=self.leaf_type,
+            leaf_kwargs=self.leaf_kwargs,
+            layer_type="linsum",
+            dropout=0.0,
+        )
+        return Einet(config)
+
+    def setUp(self) -> None:
+        self.num_features = 8
+        self.num_channels = 3
+        self.num_sums = 5
+        self.num_leaves = 2
+        self.depth = 3
+        self.leaf_type = Binomial
+        self.leaf_kwargs = {"total_count": 255}
+
+    @parameterized.expand(product([False, True], [1, 3], [1, 4]))
+    def test_sampling_shapes(self, differentiable: bool, num_classes: int, num_repetitions: int):
+        model = self.make_einet(num_classes=num_classes, num_repetitions=num_repetitions)
+        N = 2
+
+        # Sample without evidence
+        samples = model.sample(num_samples=N, is_differentiable=differentiable)
+        self.assertEqual(samples.shape, (N, self.num_channels, self.num_features))
+
+        # Sample with evidence
+        evidence = torch.randint(0, 2, size=(N, self.num_channels, self.num_features))
+        samples = model.sample(evidence=evidence, is_differentiable=differentiable)
+        self.assertEqual(samples.shape, (N, self.num_channels, self.num_features))
+
+    @parameterized.expand(product([False, True], [1, 3], [1, 4]))
+    def test_mpe_shapes(self, differentiable: bool, num_classes: int, num_repetitions: int):
+        model = self.make_einet(num_classes=num_classes, num_repetitions=num_repetitions)
+        N = 2
+
+        # MPE without evidence
+        mpe = model.mpe(is_differentiable=differentiable)
+        self.assertEqual(mpe.shape, (1, self.num_channels, self.num_features))
+
+        # MPE with evidence
+        evidence = torch.randint(0, 2, size=(N, self.num_channels, self.num_features))
+        mpe = model.mpe(evidence=evidence, is_differentiable=differentiable)
+        self.assertEqual(mpe.shape, (N, self.num_channels, self.num_features))
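The suite runs with `python -m pytest tests/` once the `parameterized` package is installed (`pip install parameterized`). As a minimal sketch of what the shape tests above exercise, assuming the `EinetConfig` fields and the `Einet.sample` signature exactly as they appear in `tests/test_einet.py`:

    import torch
    from simple_einet.einet import Einet, EinetConfig
    from simple_einet.layers.distributions.binomial import Binomial

    # Small Einet matching the test fixture above; the field names are
    # taken from the diff, not verified against the library source.
    config = EinetConfig(
        num_features=8,
        num_channels=3,
        depth=3,
        num_sums=5,
        num_leaves=2,
        num_repetitions=4,
        num_classes=1,
        leaf_type=Binomial,
        leaf_kwargs={"total_count": 255},
        layer_type="linsum",
        dropout=0.0,
    )
    model = Einet(config)

    # Per the assertions in test_sampling_shapes, each draw has shape
    # (num_samples, num_channels, num_features).
    samples = model.sample(num_samples=2, is_differentiable=True)
    assert samples.shape == (2, 3, 8)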