Commit

WIP: translation
Translation runs, and the output is somehow related to the input.
However, it definitely doesn't work correctly.
For example, you get (x * 2) - 1 rows of output for x rows of input.
Waino committed Aug 19, 2024
1 parent e6f564a commit fd492d9
Showing 8 changed files with 191 additions and 569 deletions.
4 changes: 2 additions & 2 deletions mammoth/model_builder.py
@@ -8,7 +8,7 @@
from functools import partial
from torch.nn.init import xavier_uniform_
from typing import Optional, List, Dict, Tuple
# from x_transformers import TransformerWrapper
from x_transformers import TransformerWrapper
from x_transformers.x_transformers import TokenEmbedding

from mammoth.distributed.components import (
@@ -32,7 +32,7 @@
from mammoth.utils.misc import use_gpu
from mammoth.utils.model_saver import load_frame_checkpoint
from mammoth.utils.parse import ArgumentParser
from mammoth.utils.transformer_wrapper import TransformerWrapper
# from mammoth.utils.transformer_wrapper import TransformerWrapper


def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:
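
This flips model building from mammoth's local mammoth.utils.transformer_wrapper back to the upstream x_transformers TransformerWrapper. For orientation only, a minimal sketch of the upstream wrapper's documented usage; the sizes below are placeholders, not values taken from mammoth's configuration:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    # Placeholder dimensions; mammoth derives these from its own opts.
    model = TransformerWrapper(
        num_tokens=32000,                       # vocabulary size
        max_seq_len=1024,
        attn_layers=Decoder(dim=512, depth=6, heads=8),
    )
    tokens = torch.randint(0, 32000, (2, 10))   # (batch, seq_len)
    logits = model(tokens)                      # (batch, seq_len, num_tokens)
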
4 changes: 2 additions & 2 deletions mammoth/modules/adapters.py
@@ -45,7 +45,7 @@ def as_layer_struct(self):
self.residual,
])

def apply(tmp_layer_types, tmp_layer_structs, tmp_layer_dropouts):
def apply(self, tmp_layer_types, tmp_layer_structs, tmp_layer_dropouts):
# FeedForwards are injected after the base ff
tmp_layer_types.append('f')
tmp_layer_structs.append(self.as_layer_struct())
@@ -75,7 +75,7 @@ def __init__(
def is_wrapper(self):
return True

def apply(tmp_layer_types, tmp_layer_structs, tmp_layer_dropouts):
def apply(self, tmp_layer_types, tmp_layer_structs, tmp_layer_dropouts):
# LoraAdapterLayer wraps the existing feedforward. No norms are added.
tmp_layer_structs[0][1] = self.wrap(tmp_layer_structs[0][1])
return tmp_layer_types, tmp_layer_structs, tmp_layer_dropouts
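
The only change in both hunks is the missing self parameter on apply(). Without it, any bound call receives the adapter instance as an extra positional argument, so the call fails before the adapter logic ever runs. A minimal standalone reproduction of that failure mode (not mammoth code):

    class BrokenAdapter:
        def apply(tmp_layer_types, tmp_layer_structs, tmp_layer_dropouts):
            return tmp_layer_types, tmp_layer_structs, tmp_layer_dropouts

    BrokenAdapter().apply([], [], [])
    # TypeError: apply() takes 3 positional arguments but 4 were given
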
13 changes: 9 additions & 4 deletions mammoth/modules/layer_stack.py
@@ -1,5 +1,6 @@
from torch import nn
from typing import List, Sequence, Optional, Tuple
from x_transformers.x_transformers import LayerIntermediates

from mammoth.modules.adapters import AdaptedAttentionLayers

@@ -16,14 +17,18 @@ def __init__(self, attention_layers_stack: Sequence[AdaptedAttentionLayers]):
assert len(set(attention_layers.dim for attention_layers in attention_layers_stack)) == 1, \
'All AdaptedAttentionLayers must have the same dimension'

def forward(self, x, return_hiddens=False, **kwargs):
def forward(self, x, return_hiddens=False, cache: Optional[List[LayerIntermediates]] = None, **kwargs):
all_intermediates = []
for attention_layers in self.attention_layers_stack:
for i, attention_layers in enumerate(self.attention_layers_stack):
if cache:
cache_i = cache[i]
else:
cache_i = None
if return_hiddens:
x, intermediates = attention_layers.forward(x, return_hiddens=True, **kwargs)
x, intermediates = attention_layers.forward(x, return_hiddens=True, cache=cache_i, **kwargs)
all_intermediates.append(intermediates)
else:
x = attention_layers.forward(x, return_hiddens=False, **kwargs)
x = attention_layers.forward(x, return_hiddens=False, cache=cache_i, **kwargs)
if return_hiddens:
return x, all_intermediates
else:
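
The stack now threads an optional cache through its members: one LayerIntermediates per AdaptedAttentionLayers, indexed positionally and passed down to each attention_layers.forward() call. A hedged sketch of how a decoding loop might drive this; stack, x and max_decoding_steps are placeholders, and the loop itself is not part of this commit:

    from typing import List, Optional
    from x_transformers.x_transformers import LayerIntermediates

    cache: Optional[List[LayerIntermediates]] = None
    for step in range(max_decoding_steps):
        # With return_hiddens=True the stack returns one LayerIntermediates
        # per member; feeding them back in as `cache` on the next step lets
        # each member reuse its previous state instead of recomputing it.
        x, all_intermediates = stack.forward(x, return_hiddens=True, cache=cache)
        cache = all_intermediates
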
8 changes: 2 additions & 6 deletions mammoth/translate/__init__.py
@@ -1,10 +1,9 @@
""" Modules for translation """
from mammoth.translate.translator import Translator, GeneratorLM
from mammoth.translate.translator import Translator
from mammoth.translate.translation import Translation, TranslationBuilder
from mammoth.translate.beam_search import BeamSearch, GNMTGlobalScorer
from mammoth.translate.beam_search import BeamSearchLM
from mammoth.translate.decode_strategy import DecodeStrategy
from mammoth.translate.greedy_search import GreedySearch, GreedySearchLM
from mammoth.translate.greedy_search import GreedySearch
from mammoth.translate.penalties import PenaltyBuilder
from mammoth.translate.translation_server import TranslationServer, ServerModelError

@@ -19,7 +18,4 @@
'ServerModelError',
"DecodeStrategy",
"GreedySearch",
"GreedySearchLM",
"BeamSearchLM",
"GeneratorLM",
]
140 changes: 27 additions & 113 deletions mammoth/translate/beam_search.py
@@ -73,21 +73,23 @@ def __init__(
stepwise_penalty,
ratio,
ban_unk_token,
device,
):
super(BeamSearchBase, self).__init__(
pad,
bos,
eos,
unk,
batch_size,
beam_size,
global_scorer,
min_length,
block_ngram_repeat,
exclusion_tokens,
return_attention,
max_length,
ban_unk_token,
pad=pad,
bos=bos,
eos=eos,
unk=unk,
batch_size=batch_size,
parallel_paths=beam_size,
global_scorer=global_scorer,
min_length=min_length,
block_ngram_repeat=block_ngram_repeat,
exclusion_tokens=exclusion_tokens,
return_attention=return_attention,
max_length=max_length,
ban_unk_token=ban_unk_token,
device=device,
)
# beam parameters
self.beam_size = beam_size
@@ -109,31 +111,27 @@ def __init__(
self._prev_penalty = None
self._coverage = None

self._stepwise_cov_pen = stepwise_penalty and self.global_scorer.has_cov_pen
self._vanilla_cov_pen = not stepwise_penalty and self.global_scorer.has_cov_pen
self._cov_pen = self.global_scorer.has_cov_pen

self.memory_lengths = None

def initialize(self, *args, **kwargs):
raise NotImplementedError

def initialize_(self, memory_bank, memory_lengths, src_map, device, target_prefix):
super(BeamSearchBase, self).initialize(memory_bank, memory_lengths, src_map, device, target_prefix)
def initialize_(self, target_prefix):
super(BeamSearchBase, self).initialize(target_prefix)

self.best_scores = torch.full([self.batch_size], -1e10, dtype=torch.float, device=device)
self.best_scores = torch.full([self.batch_size], -1e10, dtype=torch.float, device=self.device)
self._beam_offset = torch.arange(
0, self.batch_size * self.beam_size, step=self.beam_size, dtype=torch.long, device=device
0, self.batch_size * self.beam_size, step=self.beam_size, dtype=torch.long, device=self.device
)
self.topk_log_probs = (
torch.tensor([0.0] + [float("-inf")] * (self.beam_size - 1), device=device)
torch.tensor([0.0] + [float("-inf")] * (self.beam_size - 1), device=self.device)
.repeat(self.batch_size)
.reshape(self.batch_size, self.beam_size)
)
# buffers for the topk scores and 'backpointer'
self.topk_scores = torch.empty((self.batch_size, self.beam_size), dtype=torch.float, device=device)
self.topk_ids = torch.empty((self.batch_size, self.beam_size), dtype=torch.long, device=device)
self._batch_index = torch.empty([self.batch_size, self.beam_size], dtype=torch.long, device=device)
self.topk_scores = torch.empty((self.batch_size, self.beam_size), dtype=torch.float, device=self.device)
self.topk_ids = torch.empty((self.batch_size, self.beam_size), dtype=torch.long, device=self.device)
self._batch_index = torch.empty([self.batch_size, self.beam_size], dtype=torch.long, device=self.device)

@property
def current_predictions(self):
@@ -246,27 +244,14 @@ def remove_finished_batches(self, _B_new, _B_old, non_finished, predictions, att
self.alive_attn = attention.index_select(1, non_finished).view(
step - 1, _B_new * self.beam_size, inp_seq_len
)
if self._cov_pen:
self._coverage = (
self._coverage.view(1, _B_old, self.beam_size, inp_seq_len)
.index_select(1, non_finished)
.view(1, _B_new * self.beam_size, inp_seq_len)
)
if self._stepwise_cov_pen:
self._prev_penalty = self._prev_penalty.index_select(0, non_finished)

def advance(self, log_probs, attn):
def advance(self, logits, new_cache):
log_probs = torch.log_softmax(logits, dim=-1)
vocab_size = log_probs.size(-1)

# using integer division to get an integer _B without casting
_B = log_probs.shape[0] // self.beam_size

if self._stepwise_cov_pen and self._prev_penalty is not None:
self.topk_log_probs += self._prev_penalty
self.topk_log_probs -= self.global_scorer.cov_penalty(self._coverage + attn, self.global_scorer.beta).view(
_B, self.beam_size
)

# force the output to be longer than self.min_length
step = len(self)
self.ensure_min_length(log_probs)
@@ -305,30 +290,6 @@ def advance(self, log_probs, attn):

self.maybe_update_forbidden_tokens()

if self.return_attention or self._cov_pen:
current_attn = attn.index_select(1, self.select_indices)
if step == 1:
self.alive_attn = current_attn
# update global state (step == 1)
if self._cov_pen: # coverage penalty
self._prev_penalty = torch.zeros_like(self.topk_log_probs)
self._coverage = current_attn
else:
self.alive_attn = self.alive_attn.index_select(1, self.select_indices)
self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)
# update global state (step > 1)
if self._cov_pen:
self._coverage = self._coverage.index_select(1, self.select_indices)
self._coverage += current_attn
self._prev_penalty = self.global_scorer.cov_penalty(
self._coverage, beta=self.global_scorer.beta
).view(_B, self.beam_size)

if self._vanilla_cov_pen:
# shape: (batch_size x beam_size, 1)
cov_penalty = self.global_scorer.cov_penalty(self._coverage, beta=self.global_scorer.beta)
self.topk_scores -= cov_penalty.view(_B, self.beam_size).float()

self.is_finished = self.topk_ids.eq(self.eos)
self.ensure_max_length()

@@ -338,57 +299,10 @@ class BeamSearch(BeamSearchBase):
Beam search for seq2seq/encoder-decoder models
"""

def initialize(self, memory_bank, src_lengths, src_map=None, device=None, target_prefix=None):
def initialize(self, target_prefix=None):
"""Initialize for decoding.
Repeat src objects `beam_size` times.
"""

(fn_map_state, memory_bank, src_map, target_prefix) = self.initialize_tile(
memory_bank, src_lengths, src_map, target_prefix
)
if device is None:
device = self.get_device_from_memory_bank(memory_bank)

super(BeamSearch, self).initialize_(memory_bank, self.memory_lengths, src_map, device, target_prefix)

return fn_map_state, memory_bank, self.memory_lengths, src_map


class BeamSearchLM(BeamSearchBase):
"""
Beam search for language/decoder only models
"""

def initialize(self, src, src_lengths, src_map=None, device=None, target_prefix=None):
"""Initialize for decoding.
Repeat src objects `beam_size` times.
"""
(fn_map_state, _, src_map, target_prefix) = self.initialize_tile(None, src_lengths, src_map, target_prefix)
if device is None:
device = src.device

super(BeamSearchLM, self).initialize_(
None, self.memory_lengths, src_map=src_map, device=device, target_prefix=target_prefix
)

return fn_map_state, src, self.memory_lengths, src_map

def advance(self, log_probs, attn):
super(BeamSearchLM, self).advance(log_probs, attn)

# in LM task memory_lengths is associated with currently generated src
# and therefore needs to follow the generation
self.memory_lengths += 1

def remove_finished_batches(self, _B_new, _B_old, non_finished, predictions, attention, step):
super(BeamSearchLM, self).remove_finished_batches(_B_new, _B_old, non_finished, predictions, attention, step)

# in LM task memory_lengths is associated with currently generated src
# and therefore needs to follow the generation
non_finished = non_finished.to(self.topk_ids.device)
self.memory_lengths = (
self.memory_lengths.view(_B_old, self.beam_size).index_select(0, non_finished).view(_B_new * self.beam_size)
)
super(BeamSearch, self).initialize_(target_prefix)


class GNMTGlobalScorer(object):
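
Besides dropping the LM variant and the coverage-penalty bookkeeping, the key interface change is that advance() now takes raw decoder logits plus the updated decoder cache and applies log_softmax internally, where it previously took pre-computed log-probabilities and attention. A hedged sketch of the new call-site contract; beam, new_cache and the size variables are placeholders:

    import torch

    # Stand-in for one decoding step over the whole beam; in the real
    # translator this tensor comes from the decoder, not from randn.
    logits = torch.randn(batch_size * beam_size, vocab_size)

    beam.advance(logits, new_cache)          # log_softmax is applied inside advance
    next_tokens = beam.current_predictions   # token chosen for each live hypothesis
    finished = beam.is_finished              # (batch_size, beam_size) mask of finished hypotheses
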
44 changes: 13 additions & 31 deletions mammoth/translate/decode_strategy.py
@@ -1,5 +1,8 @@
import torch
from copy import deepcopy
from typing import Optional, List

from x_transformers.x_transformers import LayerIntermediates

from mammoth.utils.misc import tile

@@ -78,6 +81,7 @@ def __init__(
return_attention,
max_length,
ban_unk_token,
device,
):

# magic indices
@@ -108,45 +112,21 @@ def __init__(

self.exclusion_tokens = exclusion_tokens
self.return_attention = return_attention
self.device = device

self.done = False
self.cache: Optional[List[LayerIntermediates]] = None

def get_device_from_memory_bank(self, memory_bank):
if isinstance(memory_bank, tuple):
mb_device = memory_bank[0].device
else:
mb_device = memory_bank.device
return mb_device

def initialize_tile(self, memory_bank, src_lengths, src_map=None, target_prefix=None):
def fn_map_state(state, dim):
return tile(state, self.beam_size, dim=dim)

if isinstance(memory_bank, tuple):
memory_bank = tuple(tile(x, self.beam_size, dim=1) for x in memory_bank)
elif memory_bank is not None:
memory_bank = tile(memory_bank, self.beam_size, dim=1)
if src_map is not None:
src_map = tile(src_map, self.beam_size, dim=1)

self.memory_lengths = tile(src_lengths, self.beam_size)
if target_prefix is not None:
target_prefix = tile(target_prefix, self.beam_size, dim=1)

return fn_map_state, memory_bank, src_map, target_prefix

def initialize(self, memory_bank, src_lengths, src_map=None, device=None, target_prefix=None):
def initialize(self, target_prefix=None):
"""DecodeStrategy subclasses should override :func:`initialize()`.
`initialize` should be called before all actions.
used to prepare necessary ingredients for decode.
"""
if device is None:
device = torch.device('cpu')
self.alive_seq = torch.full(
[self.batch_size * self.parallel_paths, 1], self.bos, dtype=torch.long, device=device
[self.batch_size * self.parallel_paths, 1], self.bos, dtype=torch.long, device=self.device
)
self.is_finished = torch.zeros([self.batch_size, self.parallel_paths], dtype=torch.uint8, device=device)
self.is_finished = torch.zeros([self.batch_size, self.parallel_paths], dtype=torch.uint8, device=self.device)
if target_prefix is not None:
seq_len, batch_size, n_feats = target_prefix.size()
assert (
@@ -161,7 +141,9 @@ def initialize(self, memory_bank, src_lengths, src_map=None, device=None, target
self.min_length += min(prefix_non_pad) - 1

self.target_prefix = target_prefix # NOTE: forced prefix words
return None, memory_bank, src_lengths, src_map

def set_cache(self, cache):
self.cache = cache

def __len__(self):
return self.alive_seq.shape[1]
Expand Down Expand Up @@ -293,7 +275,7 @@ def maybe_update_target_prefix(self, select_index):
return
self.target_prefix = self.target_prefix.index_select(0, select_index)

def advance(self, log_probs, attn):
def advance(self, logits, new_cache):
"""DecodeStrategy subclasses should override :func:`advance()`.
Advance is used to update ``self.alive_seq``, ``self.is_finished``,
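
Together with the beam search changes, DecodeStrategy loses the memory-bank and tiling plumbing: the device is fixed at construction, initialize() only needs an optional target_prefix, and the decoder's incremental state lives in self.cache via set_cache(). A hedged sketch of the slimmed-down lifecycle from the caller's side; decode_step is a placeholder for the real decoder call, and whether advance() or the caller updates the cache is not visible in this diff:

    strategy.initialize(target_prefix=None)   # alive_seq is allocated on strategy.device
    while not strategy.done:
        # Placeholder: consume the live hypotheses and the cached decoder
        # state, produce next-token logits and the updated state.
        logits, new_cache = decode_step(strategy.alive_seq, strategy.cache)
        strategy.advance(logits, new_cache)
        strategy.set_cache(new_cache)
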