From d4f1a2fcef5a25f443803a71e2b5794b2c21ed64 Mon Sep 17 00:00:00 2001
From: Steven Braun
Date: Wed, 4 Oct 2023 07:56:45 +0200
Subject: [PATCH] tests: Add logit grad check for diff sampling

---
 simple_einet/tests/layers/test_einsum_layer.py | 12 ++++++++++++
 simple_einet/tests/layers/test_linsum_layer.py | 12 ++++++++++++
 simple_einet/tests/layers/test_mixing_layer.py | 12 ++++++++++++
 simple_einet/tests/layers/test_sum_layer.py    | 12 ++++++++++++
 simple_einet/tests/layers/test_utils.py        |  8 +++++---
 5 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/simple_einet/tests/layers/test_einsum_layer.py b/simple_einet/tests/layers/test_einsum_layer.py
index 96ecb52..01b3766 100644
--- a/simple_einet/tests/layers/test_einsum_layer.py
+++ b/simple_einet/tests/layers/test_einsum_layer.py
@@ -5,6 +5,7 @@
 
 from simple_einet.abstract_layers import logits_to_log_weights
 from simple_einet.layers.einsum import EinsumLayer
+from simple_einet.sampling_utils import index_one_hot
 from simple_einet.tests.layers.test_utils import get_sampling_context
 
 
@@ -58,3 +59,14 @@ def test__condition_weights_on_evidence(self, differentiable: bool):
         sums = log_weights.logsumexp(dim=2)
         target = torch.zeros_like(sums)
         self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    def test__differentiable_sampling_has_grads(self):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
+        ctx = self.layer.sample(ctx)
+
+        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
+        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
+        sample.mean().backward()
+        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/simple_einet/tests/layers/test_linsum_layer.py b/simple_einet/tests/layers/test_linsum_layer.py
index 9bf0575..9ea0de8 100644
--- a/simple_einet/tests/layers/test_linsum_layer.py
+++ b/simple_einet/tests/layers/test_linsum_layer.py
@@ -5,6 +5,7 @@
 
 from simple_einet.abstract_layers import logits_to_log_weights
 from simple_einet.layers.linsum import LinsumLayer
+from simple_einet.sampling_utils import index_one_hot
 from simple_einet.tests.layers.test_utils import get_sampling_context
 
 
@@ -58,3 +59,14 @@ def test__select_weights(self, differentiable: bool):
         ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
         weights = self.layer._select_weights(ctx, self.layer.logits)
         self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))
+
+    def test__differentiable_sampling_has_grads(self):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
+        ctx = self.layer.sample(ctx)
+
+        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
+        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
+        sample.mean().backward()
+        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/simple_einet/tests/layers/test_mixing_layer.py b/simple_einet/tests/layers/test_mixing_layer.py
index 84800af..7bb59a0 100644
--- a/simple_einet/tests/layers/test_mixing_layer.py
+++ b/simple_einet/tests/layers/test_mixing_layer.py
@@ -5,6 +5,7 @@
 
 from simple_einet.abstract_layers import logits_to_log_weights
 from simple_einet.layers.mixing import MixingLayer
+from simple_einet.sampling_utils import index_one_hot
 from simple_einet.tests.layers.test_utils import get_sampling_context
 
 
@@ -56,3 +57,14 @@ def test__select_weights(self, differentiable: bool):
         ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
         weights = self.layer._select_weights(ctx, self.layer.logits)
         self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))
+
+    def test__differentiable_sampling_has_grads(self):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
+        ctx = self.layer.sample(ctx)
+
+        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
+        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
+        sample.mean().backward()
+        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/simple_einet/tests/layers/test_sum_layer.py b/simple_einet/tests/layers/test_sum_layer.py
index 1e5c313..b9bd10b 100644
--- a/simple_einet/tests/layers/test_sum_layer.py
+++ b/simple_einet/tests/layers/test_sum_layer.py
@@ -5,6 +5,7 @@
 
 from simple_einet.abstract_layers import logits_to_log_weights
 from simple_einet.layers.sum import SumLayer
+from simple_einet.sampling_utils import index_one_hot
 from simple_einet.tests.layers.test_utils import get_sampling_context
 
 
@@ -58,3 +59,14 @@ def test__condition_weights_on_evidence(self, differentiable: bool):
         sums = log_weights.logsumexp(dim=2)
         target = torch.zeros_like(sums)
         self.assertTrue(torch.allclose(sums, target, atol=1e-5))
+
+    def test__differentiable_sampling_has_grads(self):
+        N = 2
+        ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
+        ctx = self.layer.sample(ctx)
+
+        sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
+        sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
+        sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
+        sample.mean().backward()
+        self.assertTrue(self.layer.logits.grad is not None)
diff --git a/simple_einet/tests/layers/test_utils.py b/simple_einet/tests/layers/test_utils.py
index 3101cbb..3feecd0 100644
--- a/simple_einet/tests/layers/test_utils.py
+++ b/simple_einet/tests/layers/test_utils.py
@@ -7,9 +7,11 @@
 def get_sampling_context(layer, num_samples: int, is_differentiable: bool = False):
     if is_differentiable:
         indices_out = torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out))
-        one_hot_indices_out = F.one_hot(indices_out, num_classes=layer.num_sums_out)
-        indices_repetition = torch.randint(low=0, high=layer.num_repetitions, size=(num_samples, 1))
-        one_hot_indices_repetition = F.one_hot(indices_repetition, num_classes=layer.num_repetitions)
+        one_hot_indices_out = F.one_hot(indices_out, num_classes=layer.num_sums_out).float()
+        indices_repetition = torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,))
+        one_hot_indices_repetition = F.one_hot(indices_repetition, num_classes=layer.num_repetitions).float()
+        one_hot_indices_out.requires_grad_(True)
+        one_hot_indices_repetition.requires_grad_(True)
         return SamplingContext(
             num_samples=num_samples,
             indices_out=one_hot_indices_out,
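
--
Reviewer note, not part of the patch: the new gradient checks hinge on
index_one_hot keeping the autograd graph connected from the selected sample
back to the layer logits. The sketch below shows the semantics the tests
assume for index_one_hot; the actual implementation in
simple_einet/sampling_utils.py may differ in details (this is an assumption
based solely on how the tests call it).

    import torch


    def index_one_hot(tensor: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor:
        # Differentiable "indexing": instead of a hard gather (which would cut
        # gradients to whatever produced `index`), multiply by the one-hot (or
        # relaxed) index mask and sum the selected dimension away.
        return (tensor * index).sum(dim=dim)


    # Miniature version of the tests' gradient check: a softmax over logits
    # acts as a relaxed one-hot index, so backward() reaches the logits.
    logits = torch.randn(4, requires_grad=True)
    index = torch.softmax(logits, dim=-1)  # relaxed one-hot over 4 choices
    values = torch.randn(2, 4)             # batch of 2 samples, 4 candidates each
    picked = index_one_hot(values, index=index, dim=-1)
    picked.mean().backward()
    assert logits.grad is not None         # gradients flowed through the indexing

Note that the tests themselves use hard one-hot indices: get_sampling_context
builds them as float tensors with requires_grad_(True), so the check verifies
graph connectivity even without a relaxation such as softmax or Gumbel-softmax.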