Skip to content

Commit

Permalink
Tests are passing
Browse files Browse the repository at this point in the history
  • Loading branch information
Waino committed Sep 2, 2024
1 parent 283d72c commit d220931
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions mammoth/tests/test_beam_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from mammoth.translate.beam_search import BeamSearch, GNMTGlobalScorer
from mammoth.utils.misc import tile

from copy import deepcopy

Expand Down Expand Up @@ -121,13 +122,13 @@ def test_advance_with_some_repeats_gets_blocked(self):
# on initial round, only predicted scores for beam 0
# matter. Make two predictions. Top one will be repeated
# in beam zero, second one will live on in beam 1.
word_probs[:, repeat_idx] = repeat_score
word_probs[:, repeat_idx + i + 1] = no_repeat_score
word_probs[0::beam_sz, repeat_idx] = repeat_score
word_probs[0::beam_sz, repeat_idx + i + 1] = no_repeat_score
else:
# predict the same thing in beam 0
word_probs[0, repeat_idx] = 0
word_probs[0::beam_sz, repeat_idx] = 0
# continue pushing around what beam 1 predicts
word_probs[1:, repeat_idx + i + 1] = 0
word_probs[1::beam_sz, repeat_idx + i + 1] = 0
attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns)
if i < ngram_repeat:
Expand All @@ -140,7 +141,7 @@ def test_advance_with_some_repeats_gets_blocked(self):
expected = torch.full([batch_sz, beam_sz], float("-inf"))
expected[:, 0] = no_repeat_score
expected[:, 1] = self.BLOCKED_SCORE
# self.assertTrue(beam.topk_log_probs.equal(expected))
self.assertTrue(beam.topk_log_probs.equal(expected))
else:
# now beam 0 dies (along with the others), beam 1 -> beam 0
self.assertFalse(beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any())
Expand Down Expand Up @@ -223,7 +224,6 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
valid_score_dist = torch.log_softmax(torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0)
min_length = 5
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
device_init = torch.zeros(1, 1)
beam = BeamSearch(
beam_sz,
Expand All @@ -244,7 +244,7 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
False,
device=device_init.device,
)
beam.initialize(lengths)
beam.initialize()
all_attns = []
for i in range(min_length + 4):
# non-interesting beams are going to get dummy values
Expand All @@ -269,6 +269,7 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
beam.advance(word_probs, attns)
if i < min_length:
expected_score_dist = (i + 1) * valid_score_dist[1:].unsqueeze(0)
# Note that when batch_sz is > 1, expected is broadcast across the batch
self.assertTrue(beam.topk_log_probs.allclose(expected_score_dist))
elif i == min_length:
# now the top beam has ended and no others have
Expand Down Expand Up @@ -350,6 +351,7 @@ def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
beam.update_finished()
self.assertTrue(beam.done)

@unittest.skip('attention no longer returned')
def test_beam_returns_attn_with_correct_length(self):
beam_sz = 5
batch_sz = 3
Expand All @@ -358,7 +360,7 @@ def test_beam_returns_attn_with_correct_length(self):
valid_score_dist = torch.log_softmax(torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0)
min_length = 5
eos_idx = 2
inp_lens = torch.randint(1, 30, (batch_sz,))
inp_lens = tile(torch.randint(1, 30, (1, batch_sz,)), beam_sz, dim=1)
device_init = torch.zeros(1, 1)
beam = BeamSearch(
beam_sz,
Expand All @@ -379,8 +381,7 @@ def test_beam_returns_attn_with_correct_length(self):
False,
device=device_init.device,
)
_, _, inp_lens, _ = beam.initialize(inp_lens)
# inp_lens is tiled in initialize, reassign to make attn match
beam.initialize(None)
for i in range(min_length + 2):
# non-interesting beams are going to get dummy values
word_probs = torch.full((batch_sz * beam_sz, n_words), -float('inf'))
Expand Down

0 comments on commit d220931

Please sign in to comment.