From c154154ebfb4356c1746312d3b6ab7453da82212 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 2 Sep 2024 18:10:22 +0300
Subject: [PATCH] Greedy search tests are also passing

---
 mammoth/tests/test_greedy_search.py | 105 ++++++++++++++++++++++++----
 mammoth/translate/beam_search.py    |   2 +-
 mammoth/translate/greedy_search.py  |  84 ++++++++--------------
 3 files changed, 122 insertions(+), 69 deletions(-)

diff --git a/mammoth/tests/test_greedy_search.py b/mammoth/tests/test_greedy_search.py
index 03805ef1..3b120bcc 100644
--- a/mammoth/tests/test_greedy_search.py
+++ b/mammoth/tests/test_greedy_search.py
@@ -28,7 +28,6 @@ class TestGreedySearch(unittest.TestCase):
 
     BLOCKED_SCORE = -10e20
 
-    @unittest.skip('TMP')
     def test_doesnt_predict_eos_if_shorter_than_min_len(self):
         # batch 0 will always predict EOS. The other batches will predict
         # non-eos scores.
@@ -40,9 +39,25 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
         eos_idx = 2
         lengths = torch.randint(0, 30, (batch_sz,))
         samp = GreedySearch(
-            0, 1, 2, 3, batch_sz, GlobalScorerStub(), min_length, False, set(), False, 30, 1.0, 1, 0, 1, False
+            0,
+            1,
+            2,
+            3,
+            batch_sz,
+            GlobalScorerStub(),
+            min_length,
+            False,
+            set(),
+            False,
+            30,
+            1.0,
+            1,
+            0,
+            1,
+            False,
+            device=lengths.device,
         )
-        samp.initialize(torch.zeros((1, 1)), lengths)
+        samp.initialize()
         all_attns = []
         for i in range(min_length + 4):
             word_probs = torch.full((batch_sz, n_words), -float('inf'))
@@ -66,7 +81,6 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
             else:  # i > min_length
                 break
 
-    @unittest.skip('TMP')
     def test_returns_correct_scores_deterministic(self):
         for batch_sz in [1, 13]:
             for temp in [1.0, 3.0]:
@@ -77,9 +91,25 @@ def test_returns_correct_scores_deterministic(self):
                 eos_idx = 2
                 lengths = torch.randint(0, 30, (batch_sz,))
                 samp = GreedySearch(
-                    0, 1, 2, 3, batch_sz, GlobalScorerStub(), 0, False, set(), False, 30, temp, 1, 0, 1, False
+                    0,
+                    1,
+                    2,
+                    3,
+                    batch_sz,
+                    GlobalScorerStub(),
+                    0,
+                    False,
+                    set(),
+                    False,
+                    30,
+                    temp,
+                    1,
+                    0,
+                    1,
+                    False,
+                    device=lengths.device,
                 )
-                samp.initialize(torch.zeros((1, 1)), lengths)
+                samp.initialize()
                 # initial step
                 i = 0
                 word_probs = torch.full((batch_sz, n_words), -float('inf'))
@@ -129,7 +159,6 @@ def test_returns_correct_scores_deterministic(self):
                     samp.update_finished()
                     self.assertTrue(samp.done)
 
-    @unittest.skip('TMP')
     def test_returns_correct_scores_non_deterministic(self):
         for batch_sz in [1, 13]:
             for temp in [1.0, 3.0]:
@@ -140,9 +169,25 @@ def test_returns_correct_scores_non_deterministic(self):
                 eos_idx = 2
                 lengths = torch.randint(0, 30, (batch_sz,))
                 samp = GreedySearch(
-                    0, 1, 2, 3, batch_sz, GlobalScorerStub(), 0, False, set(), False, 30, temp, 2, 0, 1, False
+                    0,
+                    1,
+                    2,
+                    3,
+                    batch_sz,
+                    GlobalScorerStub(),
+                    0,
+                    False,
+                    set(),
+                    False,
+                    30,
+                    temp,
+                    2,
+                    0,
+                    1,
+                    False,
+                    device=lengths.device,
                 )
-                samp.initialize(torch.zeros((1, 1)), lengths)
+                samp.initialize()
                 # initial step
                 i = 0
                 for _ in range(100):
@@ -217,7 +262,6 @@ def test_returns_correct_scores_non_deterministic(self):
 
         self.assertTrue(samp.done)
 
-    @unittest.skip('TMP')
     def test_returns_correct_scores_non_deterministic_beams(self):
         beam_size = 10
         for batch_sz in [1, 13]:
@@ -229,9 +273,25 @@ def test_returns_correct_scores_non_deterministic_beams(self):
                 eos_idx = 2
                 lengths = torch.randint(0, 30, (batch_sz,))
                 samp = GreedySearch(
-                    0, 1, 2, 3, batch_sz, GlobalScorerStub(), 0, False, set(), False, 30, temp, 50, 0, beam_size, False
+                    0,
+                    1,
+                    2,
+                    3,
+                    batch_sz,
+                    GlobalScorerStub(),
+                    0,
+                    False,
+                    set(),
+                    False,
+                    30,
+                    temp,
+                    50,
+                    0,
+                    beam_size,
+                    False,
+                    device=lengths.device,
                 )
-                samp.initialize(torch.zeros((1, 1)), lengths)
+                samp.initialize()
                 # initial step
                 # finish one beam
                 i = 0
@@ -308,7 +368,6 @@ def test_returns_correct_scores_non_deterministic_beams(self):
 
         self.assertTrue(samp.done)
 
-    @unittest.skip('TMP')
     def test_returns_correct_scores_non_deterministic_topp(self):
         for batch_sz in [1, 13]:
             for temp in [1.0, 0.3]:
@@ -319,9 +378,25 @@ def test_returns_correct_scores_non_deterministic_topp(self):
                 eos_idx = 2
                 lengths = torch.randint(0, 30, (batch_sz,))
                 samp = GreedySearch(
-                    0, 1, 2, 3, batch_sz, GlobalScorerStub(), 0, False, set(), False, -1, temp, 50, 0.5, 1, False
+                    0,
+                    1,
+                    2,
+                    3,
+                    batch_sz,
+                    GlobalScorerStub(),
+                    0,
+                    False,
+                    set(),
+                    False,
+                    -1,
+                    temp,
+                    50,
+                    0.5,
+                    1,
+                    False,
+                    device=lengths.device,
                 )
-                samp.initialize(torch.zeros((1, 1)), lengths)
+                samp.initialize()
                 # initial step
                 i = 0
                 for _ in range(100):
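Review note (not part of the applied diff): every test hunk above follows the same pattern. `GreedySearch` now takes an explicit `device` keyword, and `initialize()` no longer receives a memory bank or source lengths. A minimal driver in the style of these tests is sketched below; `GlobalScorerStub` and the argument values are taken from the test file, the positional order (pad, bos, eos, unk, batch_size, global_scorer, min_length, block_ngram_repeat, exclusion_tokens, return_attention, max_length, sampling_temp, keep_topk, keep_topp, beam_size, ban_unk_token) is inferred from the constructor shown later in this patch, and the attention tensor shape is illustrative only.

    import torch

    batch_sz, n_words, eos_idx = 5, 100, 2
    samp = GreedySearch(
        0, 1, 2, 3, batch_sz, GlobalScorerStub(),
        0, False, set(), False, 30, 1.0, 1, 0, 1, False,
        device=torch.device('cpu'),
    )
    samp.initialize()  # no memory_bank / src_lengths arguments any more
    while not samp.done:
        # force every batch item to predict EOS with log-prob 0,
        # so the search finishes after one step
        word_probs = torch.full((batch_sz, n_words), -float('inf'))
        word_probs[:, eos_idx] = 0
        attn = torch.randn(1, batch_sz, 53)  # dummy attention; shape illustrative
        samp.advance(word_probs, attn)
        if samp.is_finished.any():
            samp.update_finished()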
diff --git a/mammoth/translate/beam_search.py b/mammoth/translate/beam_search.py
index d76565e6..dd211813 100644
--- a/mammoth/translate/beam_search.py
+++ b/mammoth/translate/beam_search.py
@@ -125,7 +125,7 @@ def initialize_(self, target_prefix):
         # repeat the prefix for each beam
         target_prefix = tile(target_prefix, self.parallel_paths, dim=1)
 
-        super(BeamSearchBase, self).initialize(target_prefix)
+        super(BeamSearchBase, self).initialize(target_prefix=target_prefix)
 
         self.best_scores = torch.full([self.batch_size], -1e10, dtype=torch.float, device=self.device)
         self._beam_offset = torch.arange(
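Review note (not part of the applied diff): the one-line beam_search.py change switches the call to the base class from positional to keyword form. Since this refactor removes the leading memory-bank and source-length parameters from the base `initialize` (see the greedy_search.py hunks below), a positional call written against the old signature would silently bind `target_prefix` to the wrong parameter; the keyword call cannot mis-bind. The surrounding code also relies on `tile` to place the copies for each beam next to each other. A small sketch of the assumed semantics, using `torch.repeat_interleave` as a stand-in for `mammoth.utils.misc.tile` (an assumption about its behavior, not verified here):

    import torch

    # a prefix of one time step for a batch of three sentences: shape (1, 3)
    target_prefix = torch.tensor([[10, 11, 12]])
    parallel_paths = 2
    tiled = torch.repeat_interleave(target_prefix, parallel_paths, dim=1)
    print(tiled)  # tensor([[10, 10, 11, 11, 12, 12]]): adjacent columns per batch item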
diff --git a/mammoth/translate/greedy_search.py b/mammoth/translate/greedy_search.py
index 91251b32..02305890 100644
--- a/mammoth/translate/greedy_search.py
+++ b/mammoth/translate/greedy_search.py
@@ -1,7 +1,9 @@
 import torch
 import torch.nn.functional as F
+from einops import rearrange
 
 from mammoth.translate.decode_strategy import DecodeStrategy
+from mammoth.utils.misc import tile
 
 
 def sample_topp(logits, keep_topp):
@@ -133,21 +135,23 @@ def __init__(
         keep_topp,
         beam_size,
         ban_unk_token,
+        device,
     ):
         super(GreedySearch, self).__init__(
-            pad,
-            bos,
-            eos,
-            unk,
-            batch_size,
-            beam_size,
-            global_scorer,
-            min_length,
-            block_ngram_repeat,
-            exclusion_tokens,
-            return_attention,
-            max_length,
-            ban_unk_token,
+            pad=pad,
+            bos=bos,
+            eos=eos,
+            unk=unk,
+            batch_size=batch_size,
+            parallel_paths=beam_size,
+            global_scorer=global_scorer,
+            min_length=min_length,
+            block_ngram_repeat=block_ngram_repeat,
+            exclusion_tokens=exclusion_tokens,
+            return_attention=return_attention,
+            max_length=max_length,
+            ban_unk_token=ban_unk_token,
+            device=device,
         )
         self.sampling_temp = sampling_temp
         self.keep_topk = keep_topk
@@ -155,19 +159,22 @@ def __init__(
         self.topk_scores = None
         self.beam_size = beam_size
 
-    def initialize(self, memory_bank, src_lengths, src_map=None, device=None, target_prefix=None):
+    def initialize(self, target_prefix=None):
         """Initialize for decoding."""
-        (fn_map_state, memory_bank, src_map, target_prefix) = self.initialize_tile(
-            memory_bank, src_lengths, src_map, target_prefix
+        if target_prefix is not None:
+            if target_prefix.ndim == 1:
+                target_prefix = rearrange(target_prefix, 'b -> 1 b')
+            # repeat the prefix for each beam
+            target_prefix = tile(target_prefix, self.parallel_paths, dim=1)
+
+        super(GreedySearch, self).initialize(target_prefix=target_prefix)
+        self.select_indices = torch.arange(self.batch_size * self.beam_size, dtype=torch.long, device=self.device)
+        self.original_batch_idx = tile(
+            torch.arange(self.batch_size, dtype=torch.long, device=self.device),
+            self.parallel_paths,
+            dim=0
         )
-        if device is None:
-            device = self.get_device_from_memory_bank(memory_bank)
-
-        super(GreedySearch, self).initialize(memory_bank, src_lengths, src_map, device, target_prefix)
-        self.select_indices = torch.arange(self.batch_size * self.beam_size, dtype=torch.long, device=device)
-        self.original_batch_idx = fn_map_state(torch.arange(self.batch_size, dtype=torch.long, device=device), dim=0)
-        self.beams_scores = torch.zeros((self.batch_size * self.beam_size, 1), dtype=torch.float, device=device)
-        return fn_map_state, memory_bank, self.memory_lengths, src_map
+        self.beams_scores = torch.zeros((self.batch_size * self.beam_size, 1), dtype=torch.float, device=self.device)
 
     @property
     def current_predictions(self):
@@ -257,32 +264,3 @@ def update_finished(self):
         self.select_indices = is_alive.nonzero(as_tuple=False).view(-1)
         self.original_batch_idx = self.original_batch_idx[is_alive]
         self.maybe_update_target_prefix(self.select_indices)
-
-
-class GreedySearchLM(GreedySearch):
-    def update_finished(self):
-        super(GreedySearchLM, self).update_finished()
-        self.update_memory_lengths()
-
-    def update_memory_lengths(self):
-        is_alive = ~self.is_finished.view(-1)
-        self.memory_lengths = self.memory_lengths[is_alive]
-
-    def advance(self, log_probs, attn):
-        super(GreedySearchLM, self).advance(log_probs, attn)
-
-        # in LM task memory_lengths is associated with currently generated src
-        # and therefore needs to follow the generation
-        self.memory_lengths += 1
-
-    def initialize(self, src, src_lengths, src_map=None, device=None, target_prefix=None):
-        """Initialize for decoding."""
-
-        if device is None:
-            device = src.device
-
-        (fn_map_state, _, self.memory_lengths, src_map) = super(GreedySearchLM, self).initialize(
-            None, src_lengths, src_map, device, target_prefix
-        )
-
-        return fn_map_state, src, self.memory_lengths, src_map
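Review note (not part of the applied diff): in the new `GreedySearch.initialize`, a 1-D target prefix of shape (batch,) is first lifted to (1, batch) with `rearrange`, then tiled beam-wise, mirroring the beam_search.py path. The two bookkeeping tensors it builds can be visualized with plain torch ops; `torch.repeat_interleave` again stands in for `tile`'s assumed semantics.

    import torch

    batch_size, parallel_paths = 2, 3

    # one flat row per (batch item, beam) pair
    select_indices = torch.arange(batch_size * parallel_paths)
    # -> tensor([0, 1, 2, 3, 4, 5])

    # maps each flat row back to the source sentence it decodes
    original_batch_idx = torch.repeat_interleave(torch.arange(batch_size), parallel_paths)
    # -> tensor([0, 0, 0, 1, 1, 1])

    # running score per flat row
    beams_scores = torch.zeros((batch_size * parallel_paths, 1))

`update_finished` then filters `original_batch_idx` with the `is_alive` mask, so finished hypotheses can be attributed to the correct source sentence even as rows drop out. The now-unused `GreedySearchLM` subclass, which tracked `memory_lengths` for the LM task, is deleted at the end of the patch.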