tests: Add logit grad check for diff sampling
braun-steven committed Oct 4, 2023
1 parent 3e9a1f8 commit d4f1a2f
Showing 5 changed files with 53 additions and 3 deletions.
12 changes: 12 additions & 0 deletions simple_einet/tests/layers/test_einsum_layer.py
@@ -5,6 +5,7 @@

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.einsum import EinsumLayer
from simple_einet.sampling_utils import index_one_hot
from simple_einet.tests.layers.test_utils import get_sampling_context


@@ -58,3 +59,14 @@ def test__condition_weights_on_evidence(self, differentiable: bool):
sums = log_weights.logsumexp(dim=2)
target = torch.zeros_like(sums)
self.assertTrue(torch.allclose(sums, target, atol=1e-5))

def test__differentiable_sampling_has_grads(self):
N = 2
ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
ctx = self.layer.sample(ctx)

sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
sample.mean().backward()
self.assertTrue(self.layer.logits.grad is not None)
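The same check is added verbatim to the linsum, mixing, and sum layer tests below. For context, index_one_hot is imported from simple_einet.sampling_utils and its implementation is not part of this diff. The test pattern relies on one-hot indexing being an ordinary multiply-and-sum, so gradients can flow from the indexed result back into the sampled indices and, through them, into the layer's logits. A minimal sketch of such a helper, assuming a broadcastable one-hot index that is summed out over dim (the actual signature and implementation in the repository may differ):

import torch

def index_one_hot_sketch(tensor: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor:
    # Differentiable stand-in for integer indexing along `dim`: weight the
    # tensor with a (broadcastable) one-hot mask and sum that dimension out.
    # Because this is a plain multiply-and-sum, autograd propagates gradients
    # into both `tensor` and the one-hot `index`.
    return (tensor * index).sum(dim=dim)

In the new test, layer.sample(ctx) is expected to return indices that are differentiable functions of layer.logits, so reducing the indexed tensor with mean().backward() should populate layer.logits.grad, which is exactly what the assertion checks.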
12 changes: 12 additions & 0 deletions simple_einet/tests/layers/test_linsum_layer.py
@@ -5,6 +5,7 @@

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.linsum import LinsumLayer
from simple_einet.sampling_utils import index_one_hot
from simple_einet.tests.layers.test_utils import get_sampling_context


@@ -58,3 +59,14 @@ def test__select_weights(self, differentiable: bool):
ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
weights = self.layer._select_weights(ctx, self.layer.logits)
self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))

def test__differentiable_sampling_has_grads(self):
N = 2
ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
ctx = self.layer.sample(ctx)

sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
sample.mean().backward()
self.assertTrue(self.layer.logits.grad is not None)
12 changes: 12 additions & 0 deletions simple_einet/tests/layers/test_mixing_layer.py
@@ -5,6 +5,7 @@

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.mixing import MixingLayer
from simple_einet.sampling_utils import index_one_hot
from simple_einet.tests.layers.test_utils import get_sampling_context


@@ -56,3 +57,14 @@ def test__select_weights(self, differentiable: bool):
ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=differentiable)
weights = self.layer._select_weights(ctx, self.layer.logits)
self.assertEqual(tuple(weights.shape), (N, self.layer.num_features_out, self.layer.num_sums_in))

def test__differentiable_sampling_has_grads(self):
N = 2
ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
ctx = self.layer.sample(ctx)

sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
sample.mean().backward()
self.assertTrue(self.layer.logits.grad is not None)
12 changes: 12 additions & 0 deletions simple_einet/tests/layers/test_sum_layer.py
@@ -5,6 +5,7 @@

from simple_einet.abstract_layers import logits_to_log_weights
from simple_einet.layers.sum import SumLayer
from simple_einet.sampling_utils import index_one_hot
from simple_einet.tests.layers.test_utils import get_sampling_context


@@ -58,3 +59,14 @@ def test__condition_weights_on_evidence(self, differentiable: bool):
sums = log_weights.logsumexp(dim=2)
target = torch.zeros_like(sums)
self.assertTrue(torch.allclose(sums, target, atol=1e-5))

def test__differentiable_sampling_has_grads(self):
N = 2
ctx = get_sampling_context(layer=self.layer, num_samples=N, is_differentiable=True)
ctx = self.layer.sample(ctx)

sample = torch.randn(N, self.layer.num_features, self.layer.num_sums_in, self.layer.num_repetitions)
sample = index_one_hot(sample, index=ctx.indices_repetition.unsqueeze(1).unsqueeze(2), dim=-1)
sample = index_one_hot(sample, index=ctx.indices_out, dim=-1)
sample.mean().backward()
self.assertTrue(self.layer.logits.grad is not None)
8 changes: 5 additions & 3 deletions simple_einet/tests/layers/test_utils.py
@@ -7,9 +7,11 @@
def get_sampling_context(layer, num_samples: int, is_differentiable: bool = False):
if is_differentiable:
indices_out = torch.randint(low=0, high=layer.num_sums_out, size=(num_samples, layer.num_features_out))
one_hot_indices_out = F.one_hot(indices_out, num_classes=layer.num_sums_out)
indices_repetition = torch.randint(low=0, high=layer.num_repetitions, size=(num_samples, 1))
one_hot_indices_repetition = F.one_hot(indices_repetition, num_classes=layer.num_repetitions)
one_hot_indices_out = F.one_hot(indices_out, num_classes=layer.num_sums_out).float()
indices_repetition = torch.randint(low=0, high=layer.num_repetitions, size=(num_samples,))
one_hot_indices_repetition = F.one_hot(indices_repetition, num_classes=layer.num_repetitions).float()
one_hot_indices_out.requires_grad_(True)
one_hot_indices_repetition.requires_grad_(True)
return SamplingContext(
num_samples=num_samples,
indices_out=one_hot_indices_out,
...
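A note on the fixture change above: F.one_hot returns an int64 tensor, and PyTorch only lets floating-point (or complex) tensors require gradients, so the one-hot indices are cast to float before requires_grad_(True) is called. Drawing indices_repetition with size=(num_samples,) instead of (num_samples, 1) also makes the resulting one-hot mask come out as (num_samples, num_repetitions), which matches the unsqueeze(1).unsqueeze(2) broadcasting used in the new tests. A small illustration with hypothetical values:

import torch
import torch.nn.functional as F

idx = torch.randint(low=0, high=4, size=(2,))
one_hot = F.one_hot(idx, num_classes=4)  # int64 tensor
# one_hot.requires_grad_(True)           # would raise: only floating-point tensors can require gradients
one_hot = one_hot.float()
one_hot.requires_grad_(True)             # valid: gradients can now flow through the one-hot mask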
