From d220931573a71a4b76e4a83afaee001f1ae30360 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 2 Sep 2024 17:46:49 +0300
Subject: [PATCH] Tests are passing

---
 mammoth/tests/test_beam_search.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/mammoth/tests/test_beam_search.py b/mammoth/tests/test_beam_search.py
index d7c94bed..9af69561 100644
--- a/mammoth/tests/test_beam_search.py
+++ b/mammoth/tests/test_beam_search.py
@@ -1,5 +1,6 @@
 import unittest
 from mammoth.translate.beam_search import BeamSearch, GNMTGlobalScorer
+from mammoth.utils.misc import tile
 
 from copy import deepcopy
 
@@ -121,13 +122,13 @@ def test_advance_with_some_repeats_gets_blocked(self):
             # on initial round, only predicted scores for beam 0
             # matter. Make two predictions. Top one will be repeated
             # in beam zero, second one will live on in beam 1.
-            word_probs[:, repeat_idx] = repeat_score
-            word_probs[:, repeat_idx + i + 1] = no_repeat_score
+            word_probs[0::beam_sz, repeat_idx] = repeat_score
+            word_probs[0::beam_sz, repeat_idx + i + 1] = no_repeat_score
         else:
             # predict the same thing in beam 0
-            word_probs[0, repeat_idx] = 0
+            word_probs[0::beam_sz, repeat_idx] = 0
             # continue pushing around what beam 1 predicts
-            word_probs[1:, repeat_idx + i + 1] = 0
+            word_probs[1::beam_sz, repeat_idx + i + 1] = 0
         attns = torch.randn(1, batch_sz * beam_sz, 53)
         beam.advance(word_probs, attns)
         if i < ngram_repeat:
@@ -140,7 +141,7 @@
                 expected = torch.full([batch_sz, beam_sz], float("-inf"))
                 expected[:, 0] = no_repeat_score
                 expected[:, 1] = self.BLOCKED_SCORE
-                # self.assertTrue(beam.topk_log_probs.equal(expected))
+                self.assertTrue(beam.topk_log_probs.equal(expected))
             else:
                 # now beam 0 dies (along with the others), beam 1 -> beam 0
                 self.assertFalse(beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any())
@@ -223,7 +224,6 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
         valid_score_dist = torch.log_softmax(torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0)
         min_length = 5
         eos_idx = 2
-        lengths = torch.randint(0, 30, (batch_sz,))
         device_init = torch.zeros(1, 1)
         beam = BeamSearch(
             beam_sz,
@@ -244,7 +244,7 @@
             False,
             device=device_init.device,
         )
-        beam.initialize(lengths)
+        beam.initialize()
         all_attns = []
         for i in range(min_length + 4):
             # non-interesting beams are going to get dummy values
@@ -269,6 +269,7 @@
             beam.advance(word_probs, attns)
             if i < min_length:
                 expected_score_dist = (i + 1) * valid_score_dist[1:].unsqueeze(0)
+                # Note that when batch_sz is > 1, expected is broadcast across the batch
                 self.assertTrue(beam.topk_log_probs.allclose(expected_score_dist))
             elif i == min_length:
                 # now the top beam has ended and no others have
@@ -350,6 +351,7 @@ def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
                 beam.update_finished()
                 self.assertTrue(beam.done)
 
+    @unittest.skip('attention no longer returned')
     def test_beam_returns_attn_with_correct_length(self):
         beam_sz = 5
         batch_sz = 3
@@ -358,7 +360,7 @@
         valid_score_dist = torch.log_softmax(torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0)
         min_length = 5
         eos_idx = 2
-        inp_lens = torch.randint(1, 30, (batch_sz,))
+        inp_lens = tile(torch.randint(1, 30, (1, batch_sz,)), beam_sz, dim=1)
         device_init = torch.zeros(1, 1)
         beam = BeamSearch(
             beam_sz,
@@ -379,8 +381,7 @@ def test_beam_returns_attn_with_correct_length(self):
             False,
             device=device_init.device,
         )
-        _, _, inp_lens, _ = beam.initialize(inp_lens)
-        # inp_lens is tiled in initialize, reassign to make attn match
+        beam.initialize(None)
         for i in range(min_length + 2):
             # non-interesting beams are going to get dummy values
             word_probs = torch.full((batch_sz * beam_sz, n_words), -float('inf'))
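
Note on the indexing changes above: moving from row 0 / rows 1: to the strided
slices 0::beam_sz / 1::beam_sz assumes the flattened (batch_sz * beam_sz,
n_words) tensor is laid out batch-major, with the beam index varying fastest,
so a stride of beam_sz starting at offset k addresses beam k of every batch
item. Below is a minimal sketch of that layout, assuming
mammoth.utils.misc.tile repeats entries along the given dim like
torch.repeat_interleave (the behaviour of the OpenNMT-py tile helper this
appears to derive from):

    import torch

    batch_sz, beam_sz, n_words = 3, 5, 100

    # Rows ordered (batch0-beam0, batch0-beam1, ..., batch1-beam0, ...).
    word_probs = torch.full((batch_sz * beam_sz, n_words), -float('inf'))
    word_probs[0::beam_sz, 7] = 0.0  # hits beam 0 of every batch item
    assert word_probs[0::beam_sz].shape == (batch_sz, n_words)

    # Tiling per-example lengths across beams produces the same ordering;
    # repeat_interleave stands in for the assumed behaviour of tile(..., dim=1).
    inp_lens = torch.randint(1, 30, (1, batch_sz))
    tiled = inp_lens.repeat_interleave(beam_sz, dim=1)
    assert tiled.shape == (1, batch_sz * beam_sz)
    assert (tiled[0, :beam_sz] == inp_lens[0, 0]).all()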