Skip to content

Commit

Permalink
Greedy search tests are also passing
Browse files Browse the repository at this point in the history
  • Loading branch information
Waino committed Sep 2, 2024
1 parent ca4dc5c commit c154154
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 69 deletions.
105 changes: 90 additions & 15 deletions mammoth/tests/test_greedy_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class TestGreedySearch(unittest.TestCase):

BLOCKED_SCORE = -10e20

@unittest.skip('TMP')
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
# batch 0 will always predict EOS. The other batches will predict
# non-eos scores.
Expand All @@ -40,9 +39,25 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
samp = GreedySearch(
0, 1, 2, 3, batch_sz, GlobalScorerStub(), min_length, False, set(), False, 30, 1.0, 1, 0, 1, False
0,
1,
2,
3,
batch_sz,
GlobalScorerStub(),
min_length,
False,
set(),
False,
30,
1.0,
1,
0,
1,
False,
device=lengths.device,
)
samp.initialize(torch.zeros((1, 1)), lengths)
samp.initialize()
all_attns = []
for i in range(min_length + 4):
word_probs = torch.full((batch_sz, n_words), -float('inf'))
Expand All @@ -66,7 +81,6 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
else: # i > min_length
break

@unittest.skip('TMP')
def test_returns_correct_scores_deterministic(self):
for batch_sz in [1, 13]:
for temp in [1.0, 3.0]:
Expand All @@ -77,9 +91,25 @@ def test_returns_correct_scores_deterministic(self):
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
samp = GreedySearch(
0, 1, 2, 3, batch_sz, GlobalScorerStub(), 0, False, set(), False, 30, temp, 1, 0, 1, False
0,
1,
2,
3,
batch_sz,
GlobalScorerStub(),
0,
False,
set(),
False,
30,
temp,
1,
0,
1,
False,
device=lengths.device,
)
samp.initialize(torch.zeros((1, 1)), lengths)
samp.initialize()
# initial step
i = 0
word_probs = torch.full((batch_sz, n_words), -float('inf'))
Expand Down Expand Up @@ -129,7 +159,6 @@ def test_returns_correct_scores_deterministic(self):
samp.update_finished()
self.assertTrue(samp.done)

@unittest.skip('TMP')
def test_returns_correct_scores_non_deterministic(self):
for batch_sz in [1, 13]:
for temp in [1.0, 3.0]:
Expand All @@ -140,9 +169,25 @@ def test_returns_correct_scores_non_deterministic(self):
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
samp = GreedySearch(
0, 1, 2, 3, batch_sz, GlobalScorerStub(), 0, False, set(), False, 30, temp, 2, 0, 1, False
0,
1,
2,
3,
batch_sz,
GlobalScorerStub(),
0,
False,
set(),
False,
30,
temp,
2,
0,
1,
False,
device=lengths.device,
)
samp.initialize(torch.zeros((1, 1)), lengths)
samp.initialize()
# initial step
i = 0
for _ in range(100):
Expand Down Expand Up @@ -217,7 +262,6 @@ def test_returns_correct_scores_non_deterministic(self):

self.assertTrue(samp.done)

@unittest.skip('TMP')
def test_returns_correct_scores_non_deterministic_beams(self):
beam_size = 10
for batch_sz in [1, 13]:
Expand All @@ -229,9 +273,25 @@ def test_returns_correct_scores_non_deterministic_beams(self):
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
samp = GreedySearch(
0, 1, 2, 3, batch_sz, GlobalScorerStub(), 0, False, set(), False, 30, temp, 50, 0, beam_size, False
0,
1,
2,
3,
batch_sz,
GlobalScorerStub(),
0,
False,
set(),
False,
30,
temp,
50,
0,
beam_size,
False,
device=lengths.device,
)
samp.initialize(torch.zeros((1, 1)), lengths)
samp.initialize()
# initial step
# finish one beam
i = 0
Expand Down Expand Up @@ -308,7 +368,6 @@ def test_returns_correct_scores_non_deterministic_beams(self):

self.assertTrue(samp.done)

@unittest.skip('TMP')
def test_returns_correct_scores_non_deterministic_topp(self):
for batch_sz in [1, 13]:
for temp in [1.0, 0.3]:
Expand All @@ -319,9 +378,25 @@ def test_returns_correct_scores_non_deterministic_topp(self):
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
samp = GreedySearch(
0, 1, 2, 3, batch_sz, GlobalScorerStub(), 0, False, set(), False, -1, temp, 50, 0.5, 1, False
0,
1,
2,
3,
batch_sz,
GlobalScorerStub(),
0,
False,
set(),
False,
-1,
temp,
50,
0.5,
1,
False,
device=lengths.device,
)
samp.initialize(torch.zeros((1, 1)), lengths)
samp.initialize()
# initial step
i = 0
for _ in range(100):
Expand Down
2 changes: 1 addition & 1 deletion mammoth/translate/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def initialize_(self, target_prefix):
# repeat the prefix for each beam
target_prefix = tile(target_prefix, self.parallel_paths, dim=1)

super(BeamSearchBase, self).initialize(target_prefix)
super(BeamSearchBase, self).initialize(target_prefix=target_prefix)

self.best_scores = torch.full([self.batch_size], -1e10, dtype=torch.float, device=self.device)
self._beam_offset = torch.arange(
Expand Down
84 changes: 31 additions & 53 deletions mammoth/translate/greedy_search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import torch
import torch.nn.functional as F
from einops import rearrange

from mammoth.translate.decode_strategy import DecodeStrategy
from mammoth.utils.misc import tile


def sample_topp(logits, keep_topp):
Expand Down Expand Up @@ -133,41 +135,46 @@ def __init__(
keep_topp,
beam_size,
ban_unk_token,
device,
):
super(GreedySearch, self).__init__(
pad,
bos,
eos,
unk,
batch_size,
beam_size,
global_scorer,
min_length,
block_ngram_repeat,
exclusion_tokens,
return_attention,
max_length,
ban_unk_token,
pad=pad,
bos=bos,
eos=eos,
unk=unk,
batch_size=batch_size,
parallel_paths=beam_size,
global_scorer=global_scorer,
min_length=min_length,
block_ngram_repeat=block_ngram_repeat,
exclusion_tokens=exclusion_tokens,
return_attention=return_attention,
max_length=max_length,
ban_unk_token=ban_unk_token,
device=device,
)
self.sampling_temp = sampling_temp
self.keep_topk = keep_topk
self.keep_topp = keep_topp
self.topk_scores = None
self.beam_size = beam_size

def initialize(self, memory_bank, src_lengths, src_map=None, device=None, target_prefix=None):
def initialize(self, target_prefix=None):
"""Initialize for decoding."""
(fn_map_state, memory_bank, src_map, target_prefix) = self.initialize_tile(
memory_bank, src_lengths, src_map, target_prefix
if target_prefix is not None:
if target_prefix.ndim == 1:
target_prefix = rearrange(target_prefix, 'b -> 1 b')
# repeat the prefix for each beam
target_prefix = tile(target_prefix, self.parallel_paths, dim=1)

super(GreedySearch, self).initialize(target_prefix=target_prefix)
self.select_indices = torch.arange(self.batch_size * self.beam_size, dtype=torch.long, device=self.device)
self.original_batch_idx = tile(
torch.arange(self.batch_size, dtype=torch.long, device=self.device),
self.parallel_paths,
dim=0
)
if device is None:
device = self.get_device_from_memory_bank(memory_bank)

super(GreedySearch, self).initialize(memory_bank, src_lengths, src_map, device, target_prefix)
self.select_indices = torch.arange(self.batch_size * self.beam_size, dtype=torch.long, device=device)
self.original_batch_idx = fn_map_state(torch.arange(self.batch_size, dtype=torch.long, device=device), dim=0)
self.beams_scores = torch.zeros((self.batch_size * self.beam_size, 1), dtype=torch.float, device=device)
return fn_map_state, memory_bank, self.memory_lengths, src_map
self.beams_scores = torch.zeros((self.batch_size * self.beam_size, 1), dtype=torch.float, device=self.device)

@property
def current_predictions(self):
Expand Down Expand Up @@ -257,32 +264,3 @@ def update_finished(self):
self.select_indices = is_alive.nonzero(as_tuple=False).view(-1)
self.original_batch_idx = self.original_batch_idx[is_alive]
self.maybe_update_target_prefix(self.select_indices)


class GreedySearchLM(GreedySearch):
def update_finished(self):
super(GreedySearchLM, self).update_finished()
self.update_memory_lengths()

def update_memory_lengths(self):
is_alive = ~self.is_finished.view(-1)
self.memory_lengths = self.memory_lengths[is_alive]

def advance(self, log_probs, attn):
super(GreedySearchLM, self).advance(log_probs, attn)

# in LM task memory_lengths is associated with currently generated src
# and therefore needs to follow the generation
self.memory_lengths += 1

def initialize(self, src, src_lengths, src_map=None, device=None, target_prefix=None):
"""Initialize for decoding."""

if device is None:
device = src.device

(fn_map_state, _, self.memory_lengths, src_map) = super(GreedySearchLM, self).initialize(
None, src_lengths, src_map, device, target_prefix
)

return fn_map_state, src, self.memory_lengths, src_map

0 comments on commit c154154

Please sign in to comment.