Commit
Remove LM stuff
Waino committed Aug 19, 2024
1 parent 671cf30 commit a8c1697
Showing 4 changed files with 18 additions and 131 deletions.
82 changes: 0 additions & 82 deletions mammoth/tests/test_beam_search.py
@@ -1,6 +1,5 @@
import unittest
from mammoth.translate.beam_search import BeamSearch, GNMTGlobalScorer
from mammoth.translate.beam_search import BeamSearchLM

from copy import deepcopy

@@ -583,84 +582,3 @@ def test_beam_advance_against_known_reference(self):
expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
self.third_step(beam, expected_beam_scores, 5)


class TestBeamSearchLM(TestBeamSearchAgainstReferenceCase):
def finish_first_beam_step(self, beam):
scores_finish = torch.log_softmax(
torch.tensor(
[
[0, 0, 10000, 0, 5000, 0.51, 0.2, 0], # beam 0 shouldn't cont
[100000, 100001, 0, 0, 0, 0, 0, 0],
[0, 100000, 0, 0, 0, 5000, 0, 0],
[0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
[0, 0, 0, 0, 0.2, 0.2, 0.2, 0.2],
] # beam 4 -> beam 1 should die
),
dim=1,
)
scores_finish = scores_finish.repeat(self.BATCH_SZ, 1)
scores_finish[: self.BEAM_SZ, beam.eos] = 0
beam.advance(scores_finish, None)

any_finished = beam.is_finished.any()
if any_finished:
beam.update_finished()

def test_beam_lm_increase_memory_length(self):
beam = BeamSearchLM(
self.BEAM_SZ,
self.BATCH_SZ,
0,
1,
2,
3,
self.N_BEST,
GlobalScorerStub(),
0,
30,
False,
0,
set(),
False,
0.0,
False,
)
device_init = torch.zeros(1, 1)
src_lengths = torch.randint(0, 30, (self.BATCH_SZ,))
fn_map_state, _, _, _ = beam.initialize(device_init, src_lengths)
expected_beam_scores = self.init_step(beam, 1)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
self.third_step(beam, expected_beam_scores, 1)

n_steps = beam.alive_seq.shape[-1] - 1
self.assertTrue(beam.memory_lengths.equal(n_steps + fn_map_state(src_lengths, dim=0)))

def test_beam_lm_update_memory_length_when_finished(self):
beam = BeamSearchLM(
self.BEAM_SZ,
self.BATCH_SZ,
0,
1,
2,
3,
self.N_BEST,
GlobalScorerStub(),
0,
30,
False,
0,
set(),
False,
0.0,
False,
)
device_init = torch.zeros(1, 1)
src_lengths = torch.randint(0, 30, (self.BATCH_SZ,))
fn_map_state, _, _, _ = beam.initialize(device_init, src_lengths)
self.init_step(beam, 1)
self.finish_first_beam_step(beam)

n_steps = beam.alive_seq.shape[-1] - 1
self.assertTrue(beam.memory_lengths.equal(n_steps + fn_map_state(src_lengths[1:], dim=0)))
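The deleted tests above pinned down the one behavior that set `BeamSearchLM` apart from encoder-decoder beam search: because a language model attends over its own growing sequence, `memory_lengths` must advance by one per decoding step, and shrink to the surviving examples once a batch entry finishes. A minimal sketch of that invariant (`expected_memory_lengths` is a hypothetical helper, and the tile-across-beams behavior assumed for `fn_map_state` is inferred from the tests):

```python
import torch

# Minimal sketch of the invariant the removed tests asserted -- not the
# actual BeamSearchLM implementation. After n_steps of decoding, every
# alive hypothesis attends over its source plus all generated tokens.
def expected_memory_lengths(src_lengths: torch.Tensor, beam_size: int, n_steps: int) -> torch.Tensor:
    # Assumed fn_map_state behavior: tile per-example lengths across beams.
    tiled = src_lengths.repeat_interleave(beam_size, dim=0)
    return tiled + n_steps

print(expected_memory_lengths(torch.tensor([7, 12]), beam_size=4, n_steps=3))
# tensor([10, 10, 10, 10, 15, 15, 15, 15])
```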
2 changes: 1 addition & 1 deletion mammoth/tests/test_models.py
@@ -5,7 +5,7 @@

import mammoth
import mammoth.opts
from mammoth.model_builder import build_encoder, build_decoder
from mammoth.model_builder import build_xcoder
from mammoth.inputters.vocab import Vocab, DEFAULT_SPECIALS
from mammoth.utils.parse import ArgumentParser

36 changes: 2 additions & 34 deletions mammoth/tests/test_translator.py
@@ -1,34 +1,2 @@
import unittest
from mammoth.translate import GeneratorLM
import torch


class TestGeneratorLM(unittest.TestCase):
def test_split_src_to_prevent_padding_target_prefix_is_none_when_equal_size( # noqa: E501
self,
):
src = torch.randint(0, 10, (5, 6))
src_lengths = 5 * torch.ones(5)
(
src,
src_lengths,
target_prefix,
) = GeneratorLM.split_src_to_prevent_padding(src, src_lengths)
self.assertIsNone(target_prefix)

def test_split_src_to_prevent_padding_target_prefix_is_ok_when_different_size( # noqa: E501
self,
):
default_length = 5
src = torch.randint(0, 10, (default_length, 6))
src_lengths = default_length * torch.ones(6, dtype=torch.int)
new_length = 4
src_lengths[1] = new_length
(
src,
src_lengths,
target_prefix,
) = GeneratorLM.split_src_to_prevent_padding(src, src_lengths)
self.assertTupleEqual(src.shape, (new_length, 6))
self.assertTupleEqual(target_prefix.shape, (1, 6))
self.assertTrue(src_lengths.equal(new_length * torch.ones(6, dtype=torch.int)))
# import unittest
# import torch
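For reference, the deleted assertions fully determine the behavior of the removed `GeneratorLM.split_src_to_prevent_padding`. A hedged reconstruction, inferred from the tests rather than from the removed source:

```python
import torch

# Reconstruction inferred from the deleted tests above, not the original code.
# src: (time, batch); src_lengths: (batch,).
def split_src_to_prevent_padding(src, src_lengths):
    min_len = int(src_lengths.min())
    target_prefix = None
    if min_len != src.shape[0]:
        # Keep only the prefix common to all examples; the remaining steps
        # become a forced decoding prefix, so no example needs padding.
        target_prefix = src[min_len:]
        src = src[:min_len]
        src_lengths = torch.full_like(src_lengths, min_len)
    return src, src_lengths, target_prefix
```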
29 changes: 15 additions & 14 deletions mammoth/translate/translator.py
@@ -515,6 +515,8 @@ def _translate(
# batch_size=batch_size,
# batch_type=batch_type,
task=self.task,
stride=1,
offset=0,
).to(self._device)

batches = build_dataloader(
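The two added keyword arguments look like the stride/offset sharding parameters used for distributed data loading; that reading is an assumption, but under it `stride=1, offset=0` simply makes the translator consume every example:

```python
# Assumed semantics of stride/offset sharding: worker i of n reads
# examples[offset::stride] with stride=n, offset=i.
examples = list(range(10))
assert examples[0::1] == examples     # stride=1, offset=0: read everything
assert examples[1::4] == [1, 5, 9]    # e.g. worker 1 of 4
```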
@@ -806,17 +808,17 @@ def translate_batch(self, batch, src_vocabs, attn_debug):
def _run_encoder(self, active_encoder, batch):
src = rearrange(batch.src.tensor, 't b 1 -> b t')
src_mask = rearrange(batch.src.mask, 't b -> b t')
quishape('src in _run_encoder', src)
quishape('src_mask in _run_encoder', src_mask)
# quishape('src in _run_encoder', src)
# quishape('src_mask in _run_encoder', src_mask)
encoder_output = active_encoder(
x=src,
mask=src_mask,
return_embeddings=True,
)
quishape('encoder_output in _run_encoder', encoder_output)
# quishape('encoder_output in _run_encoder', encoder_output)

encoder_output, alphas = self.model.attention_bridge(encoder_output, src_mask)
quishape('encoder_output after AB', encoder_output)
# quishape('encoder_output after AB', encoder_output)
if self.model.attention_bridge.is_fixed_length:
# turn off masking in the transformer decoder
src_mask = None
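Disabling the mask here is sound because a fixed-length attention bridge maps every source, whatever its padded length, to the same number of output vectors, leaving no padding positions for the decoder to mask. A toy illustration with hypothetical sizes (not the bridge's actual implementation):

```python
import torch

# Toy fixed-length bridge: k learned queries cross-attend over the encoder
# states, so the output is always (batch, k, dim) and needs no source mask.
batch, src_len, dim, k = 2, 11, 8, 4
encoder_output = torch.randn(batch, src_len, dim)
queries = torch.randn(k, dim)
attn = torch.softmax(torch.einsum('kd,btd->bkt', queries, encoder_output), dim=-1)
fixed = torch.einsum('bkt,btd->bkd', attn, encoder_output)
assert fixed.shape == (batch, k, dim)
```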
@@ -849,12 +851,12 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy):
adapter_ids=metadata.decoder_adapter_ids,
)

quishape('batch.src.tensor', batch.src.tensor)
quishape('batch.src.mask', batch.src.mask)
# quishape('batch.src.tensor', batch.src.tensor)
# quishape('batch.src.mask', batch.src.mask)

# (2) Run the encoder on the src
encoder_output, src_mask = self._run_encoder(active_encoder, batch)
quishape('src_mask', src_mask)
# quishape('src_mask', src_mask)

# (3) Decode and score the gold targets
gold_score = self._gold_score(
@@ -876,14 +878,14 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy):

# (5) Begin decoding step by step:
for step in range(decode_strategy.max_length):
quishape('alive_seq', decode_strategy.alive_seq)
# quishape('alive_seq', decode_strategy.alive_seq)
decoder_input = decode_strategy.alive_seq
encoder_output_tiled = tile(encoder_output, decode_strategy.parallel_paths, dim=0)
src_mask_tiled = tile(src_mask, decode_strategy.parallel_paths, dim=0)
quishape('decoder_input', decoder_input)
quishape('encoder_output', encoder_output)
quishape('encoder_output_tiled', encoder_output_tiled)
quishape('src_mask_tiled', src_mask_tiled)
# quishape('decoder_input', decoder_input)
# quishape('encoder_output', encoder_output)
# quishape('encoder_output_tiled', encoder_output_tiled)
# quishape('src_mask_tiled', src_mask_tiled)

logits, new_cache = active_decoder(
decoder_input,
Expand All @@ -895,8 +897,7 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy):
cache=decode_strategy.cache,
seq_start_pos=seq_start_pos,
)
quishape('logits', logits)
print('new_cache len', len(new_cache))
# quishape('logits', logits)
# new_cache is a list of LayerIntermediates objects, one for each layer_stack

if active_decoder.can_cache_kv:
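The `can_cache_kv` branch (collapsed above) feeds `new_cache` back into the next iteration. The general pattern it relies on, sketched here for illustration only (MAMMOTH's actual cache is a list of `LayerIntermediates` objects, per the comment above, not this structure):

```python
import torch

# Illustrative kv-cache bookkeeping, not MAMMOTH's decoder API: per layer,
# keep the keys/values of all previous steps and append the new step's pair,
# so attention at step t costs O(t) instead of re-running the whole prefix.
def append_to_cache(cache, layer, new_k, new_v):
    # new_k / new_v: (batch, heads, 1, head_dim) for the latest token.
    if layer == len(cache):
        cache.append((new_k, new_v))
    else:
        k, v = cache[layer]
        cache[layer] = (torch.cat([k, new_k], dim=2),
                        torch.cat([v, new_v], dim=2))
    return cache[layer]
```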
