
Commit 89f3163
WIP: dimension ordering messed up
Waino committed Aug 5, 2024
1 parent e09763f commit 89f3163
Showing 20 changed files with 340 additions and 869 deletions.
41 changes: 23 additions & 18 deletions mammoth/distributed/components.py
@@ -39,7 +39,8 @@ class Side(Enum):
decoder = auto()


@dataclass
# mypy doesn't like abstract dataclasses
@dataclass # type: ignore
class DistributedComponent(ABC):
"""
Represents a model component that may be distributed across several
@@ -78,11 +79,16 @@ def needs_communication(self) -> bool:


# TODO: This is a misnomer: Not an entire XCoder, but just one AttentionLayers block
@dataclass
@dataclass # type: ignore
class DistributedXCoder(DistributedComponent, ABC):
layer_stack_index: int
xcoder_id: str

@property
@abstractmethod
def side(self) -> Side:
pass

def get_name(self) -> str:
return f'{self.side.name}_{self.layer_stack_index}_{self.xcoder_id}'

@@ -106,7 +112,11 @@ def encoder_id(self) -> str:
return self.xcoder_id

def get_module(self, model: NMTModel) -> nn.Module:
return model.encoder.get_submodule(self.layer_stack_index, self.xcoder_id)
a_task_id = sorted(self.task_ids)[0]
aal = model.encoder.get_attention_layers(a_task_id, self.layer_stack_index)
assert aal.xcoder_id == self.xcoder_id, \
f'{self.get_name()} {self.layer_stack_index}: expected {self.xcoder_id} found {aal.xcoder_id}'
return aal


@dataclass
@@ -120,7 +130,11 @@ def decoder_id(self) -> str:
return self.xcoder_id

def get_module(self, model: NMTModel) -> nn.Module:
return model.decoder.get_submodule(self.layer_stack_index, self.xcoder_id)
a_task_id = sorted(self.task_ids)[0]
aal = model.decoder.get_attention_layers(a_task_id, self.layer_stack_index)
assert aal.xcoder_id == self.xcoder_id, \
f'{self.get_name()} {self.layer_stack_index}: expected {self.xcoder_id} found {aal.xcoder_id}'
return aal


@dataclass
@@ -133,21 +147,12 @@ def get_name(self) -> str:
return f'{side_str}_embeddings_{self.lang}'

def get_module(self, model: NMTModel) -> nn.Module:
a_task_id = sorted(self.task_ids)[0]
# FIXME: embeddings should be pre-created and stored in a dict keyed by lang
if self.side == Side.encoder:
return model.encoder.embeddings[f'embeddings_{self.lang}']
return model.encoder.get_embedding(a_task_id)
else:
return model.decoder.embeddings[f'embeddings_{self.lang}']


@dataclass
class DistributedGenerator(DistributedComponent):
lang: str

def get_name(self) -> str:
return f'generator_{self.lang}'

def get_module(self, model: NMTModel) -> nn.Module:
return model.generator[f'generator_{self.lang}']
return model.decoder.get_embedding(a_task_id)


@dataclass
@@ -175,7 +180,7 @@ def get_name(self) -> str:
return 'attention_bridge'

def get_module(self, model: NMTModel) -> Optional[nn.Module]:
return self.model.attention_bridge
return model.attention_bridge


@dataclass
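
The reworked get_module implementations no longer look the block up by submodule name; they pick a representative task id, fetch its AttentionLayers block, and assert that the block's xcoder_id matches. A minimal usage sketch of that pattern, with illustrative construction arguments and an assumed pre-built NMTModel named model (none of this appears in the diff itself):

    from mammoth.distributed.components import DistributedEncoder

    # Hypothetical component; any member task id can serve for the lookup,
    # since get_module sorts task_ids and uses the first one.
    component = DistributedEncoder(
        global_ranks={0},
        task_ids={'train_de-en'},
        group=None,
        layer_stack_index=0,
        xcoder_id='enc0',
    )
    block = component.get_module(model)   # an AdaptedAttentionLayers block
    trainable = [p for p in block.parameters() if p.requires_grad]
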
9 changes: 0 additions & 9 deletions mammoth/distributed/tasks.py
@@ -20,7 +20,6 @@
DistributedDecoder,
DistributedEmbedding,
DistributedEncoder,
DistributedGenerator,
Side,
)
from mammoth.distributed.contexts import DeviceContext, WorldContext
@@ -367,14 +366,6 @@ def create_all_distributed_components(
lang=task.tgt_lang,
)
)
builder.add(
DistributedGenerator(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
lang=task.tgt_lang,
)
)
for layer_stack_index, encoder_id in enumerate(task.encoder_id):
builder.add(
DistributedEncoder(
39 changes: 21 additions & 18 deletions mammoth/inputters/dataset.py
@@ -1,11 +1,12 @@
import collections
from dataclasses import dataclass
import gzip
import itertools
from dataclasses import dataclass
from functools import partial
import gzip
from io import IOBase

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset

@@ -15,18 +16,21 @@
from mammoth.inputters.vocab import Vocab


TensorWithMask = collections.namedtuple('TensorWithMask', 'tensor mask')


@dataclass
class Batch():
src: tuple # of torch Tensors
tgt: torch.Tensor
labels: torch.Tensor
src: TensorWithMask
tgt: TensorWithMask
labels: Tensor
batch_size: int
line_idx: int

def to(self, device):
self.src = (self.src[0].to(device), self.src[1].to(device))
self.src = TensorWithMask(self.src.tensor.to(device), self.src.mask.to(device))
if self.tgt is not None:
self.tgt = self.tgt.to(device)
self.tgt = TensorWithMask(self.tgt.tensor.to(device), self.tgt.mask.to(device))
if self.labels is not None:
self.labels = self.labels.to(device)
return self
@@ -196,23 +200,22 @@ def _cast(example_dict):

def collate_fn(self, examples, line_idx):
has_tgt = 'tgt' in examples[0].keys()
src_padidx = self.vocabs['src'][DefaultTokens.PAD]
tgt_padidx = self.vocabs['tgt'][DefaultTokens.PAD]
if self.max_length is None:
src_lengths = torch.tensor([ex['src'].numel() for ex in examples], device='cpu')
else:
src_lengths = torch.tensor([min(ex['src'].numel(), self.max_length) for ex in examples], device='cpu')
src = (self._pad_sequence([ex['src'] for ex in examples], padding_value=src_padidx), src_lengths)
src_padding_idx = self.vocabs['src'][DefaultTokens.PAD]
tgt_padding_idx = self.vocabs['tgt'][DefaultTokens.PAD]
src = self._pad_sequence([ex['src'] for ex in examples], padding_value=src_padding_idx)
src_mask = src[:, :, 0].ne(src_padding_idx)
if has_tgt:
tgt = self._pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padidx)
tgt = self._pad_sequence([ex['tgt'] for ex in examples], padding_value=tgt_padding_idx)
tgt_mask = tgt[:, :, 0].ne(tgt_padding_idx)
if 'labels' not in examples[0].keys():
labels = tgt
else:
labels = self._pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padidx)
labels = self._pad_sequence([ex['labels'] for ex in examples], padding_value=tgt_padding_idx)
tgt_with_mask = TensorWithMask(tgt, tgt_mask)
else:
tgt = None
tgt_with_mask = None
labels = None
batch = Batch(src, tgt, labels, len(examples), line_idx)
batch = Batch(TensorWithMask(src, src_mask), tgt_with_mask, labels, len(examples), line_idx)
return batch


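
With src and tgt now wrapped in TensorWithMask, downstream code reads the padded tensor and its padding mask by attribute instead of by tuple index. A short consumption sketch; the variable names are assumptions, and the exact dimension ordering is still in flux in this WIP commit:

    # Hypothetical consumer of the new Batch layout.
    batch = batch.to(device)        # moves tensor and mask for both src and tgt
    src_tokens = batch.src.tensor   # padded source token ids
    src_mask = batch.src.mask       # True where a position holds a real token
    if batch.tgt is not None:
        tgt_tokens, tgt_mask = batch.tgt   # a namedtuple also unpacks like a plain tuple
    labels = batch.labels           # falls back to tgt when no separate labels exist
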
55 changes: 44 additions & 11 deletions mammoth/model_builder.py
@@ -8,8 +8,9 @@
from functools import partial
from pathlib import Path
from torch.nn.init import xavier_uniform_
from typing import Optional, List
from x_transformers import TransformerWrapper
from typing import Optional, List, Dict, Tuple
# from x_transformers import TransformerWrapper
from x_transformers.x_transformers import TokenEmbedding

from mammoth.distributed.components import (
DistributedAdapter,
@@ -18,20 +19,21 @@
DistributedEncoder,
Side,
)
from mammoth.models import NMTModel
from mammoth.modules.adapters import (
AdaptedAttentionLayers,
Adapter,
FeedForwardAdapterLayer,
LoraAdapterLayer,
)
from mammoth.inputters.vocab import Vocab
from mammoth.models import NMTModel
from mammoth.modules.attention_bridge import AttentionBridge
from mammoth.modules.layer_stack import AdaptedAttentionLayersStack, StackXcoder
from mammoth.utils.logging import logger
from mammoth.utils.misc import use_gpu
from mammoth.utils.module_splitter import _combine_ordered_dicts
from mammoth.utils.parse import ArgumentParser

from mammoth.modules.attention_bridge import AttentionBridge
from mammoth.utils.transformer_wrapper import TransformerWrapper


def uses_adapters(opts):
@@ -257,11 +259,22 @@ def get_attention_layers_kwargs(
def build_xcoder(
side: Side,
model_opts,
vocabs_dict,
vocabs_dict: Dict[Tuple[str, str], Vocab],
device,
task_queue_manager,
single_task: Optional[str] = None,
):
token_embs: Optional[Dict[str, Vocab]] = None,
) -> StackXcoder:
"""
Build a StackXcoder for use as either Encoder or Decoder.
side: a Side enum from distributed components
model_opts: options
vocabs_dict: A dict mapping ('src'|'tgt', lang) to a Vocab.
device: torch.device
task_queue_manager: TaskQueueManager
single_task: if a task_id string is given, the built model contains only the components necessary for that task.
token_embs: to tie encoder and decoder embeddings, pass existing embeddings here.
"""
my_components: List[DistributedComponent] = task_queue_manager.get_my_distributed_components()
my_components = [
component for component in my_components
@@ -295,7 +308,11 @@ def build_xcoder(
xcoder_id=xcoder_id,
model_opts=model_opts,
)
attention_layer_blocks[layer_stack_index][xcoder_id] = AdaptedAttentionLayers(**attention_layers_kwargs)
attention_layer_blocks[layer_stack_index][xcoder_id] = AdaptedAttentionLayers(
layer_stack_index=layer_stack_index,
xcoder_id=xcoder_id,
**attention_layers_kwargs
)

# Create AdapterLayer objects and Adapter objects
if uses_adapters(model_opts):
@@ -344,6 +361,23 @@ def build_xcoder(
for attention_layers in attention_layer_blocks[layer_stack_index]:
attention_layers.add_adapter(adapter)

# Create TokenEmbedding objects
l2norm_embed = False
if side == Side.encoder:
all_langs = sorted(set(task_queue_manager.get_my_src_langs()))
else:
all_langs = sorted(set(task_queue_manager.get_my_tgt_langs()))
side_alt_str = 'src' if side == Side.encoder else 'tgt'
if token_embs is None:
token_embs = dict()
for lang in all_langs:
if lang not in token_embs:
vocab = vocabs_dict[(side_alt_str, lang)]
token_embs[lang] = TokenEmbedding(
dim=model_opts.model_dim,
num_tokens=len(vocab),
l2norm_embed=l2norm_embed
)
# Create AdaptedAttentionLayersStack objects and TransformerWrapper objects
tasks = task_queue_manager.get_my_tasks()
if single_task:
@@ -362,16 +396,14 @@
attention_layers_stack=attention_layers_stack
)

side_alt_str = 'src' if side == Side.encoder else 'tgt'
lang = task.src_lang if side == Side.encoder else task.tgt_lang
vocab = vocabs_dict[(side_alt_str, lang)]
max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
post_emb_norm = True
tie_embedding = True
use_abs_pos_emb = True
emb_frac_gradient = 1.
# FIXME: this won't work: creates embeddings for each task, not for each language
# Have to reimplement TransformerWrapper to allow passing in an embedding
# Using custom extended TransformerWrapper to allow passing in an embedding
transformer_wrapper = TransformerWrapper(
num_tokens=len(vocab),
max_seq_len=max_seq_len,
Expand All @@ -381,6 +413,7 @@ def build_xcoder(
tie_embedding=tie_embedding,
use_abs_pos_emb=use_abs_pos_emb,
emb_frac_gradient=emb_frac_gradient,
token_emb=token_embs[lang],
)
transformer_wrappers[task.corpus_id] = transformer_wrapper

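
The new token_embs argument lets the caller create the per-language TokenEmbedding objects once and pass the same dict to both build_xcoder calls, which is how encoder and decoder embeddings can be tied. A hedged sketch of that pattern; it assumes model_opts, vocabs_dict, device and task_queue_manager already exist, and that source and target sides share vocabularies, otherwise tying makes no sense:

    from x_transformers.x_transformers import TokenEmbedding
    from mammoth.distributed.components import Side
    from mammoth.model_builder import build_xcoder

    # Build one embedding per language up front (same constructor call as in the diff above).
    shared_embs = {
        lang: TokenEmbedding(dim=model_opts.model_dim, num_tokens=len(vocab), l2norm_embed=False)
        for (side_str, lang), vocab in vocabs_dict.items()
        if side_str == 'src'
    }
    encoder = build_xcoder(Side.encoder, model_opts, vocabs_dict, device, task_queue_manager,
                           token_embs=shared_embs)
    decoder = build_xcoder(Side.decoder, model_opts, vocabs_dict, device, task_queue_manager,
                           token_embs=shared_embs)
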
57 changes: 40 additions & 17 deletions mammoth/models/model.py
@@ -11,7 +11,7 @@ class BaseModel(nn.Module):
def __init__(self, encoder, decoder, attention_bridge):
super(BaseModel, self).__init__()

def forward(self, src, tgt, lengths, bptt=False, with_align=False):
def forward(self, src, tgt, lengths, return_attention=False):
"""Forward propagate a `src` and `tgt` pair for training.
Possible initialized with a beginning decoder state.
@@ -25,7 +25,7 @@ def forward(self, src, tgt, lengths, bptt=False, with_align=False):
lengths(LongTensor): The src lengths, pre-padding ``(batch,)``.
bptt (Boolean): A flag indicating if truncated bptt is set.
If reset then init_state
with_align (Boolean): A flag indicating whether output alignment,
return_attention (Boolean): A flag indicating whether output attention,
Only valid for transformer decoder.
Returns:
@@ -58,24 +58,47 @@ def __init__(self, encoder, decoder, attention_bridge):
self.decoder = decoder
self.attention_bridge = attention_bridge

def forward(self, src, tgt, lengths, bptt=False, with_align=False, metadata=None):
dec_in = tgt[:-1] # exclude last target from inputs

def forward(self, src, decoder_input, src_mask, return_attention=False, metadata=None):
# Activate the correct pluggable embeddings and modules
self.encoder.activate(metadata)
self.decoder.activate(metadata)

enc_state, memory_bank, lengths, mask = self.encoder(src, lengths)

memory_bank, alphas = self.attention_bridge(memory_bank, mask)
active_encoder = self.encoder.activate(
task_id=metadata.corpus_id,
adapter_ids=metadata.encoder_adapter_ids,
)
active_decoder = self.decoder.activate(
task_id=metadata.corpus_id,
adapter_ids=metadata.decoder_adapter_ids,
)

# QUI logging for batch shapes
def quishape(name, val):
print(f'{name} {val.shape} {val.shape[0] * val.shape[1]}')
quishape('src', src)
quishape('src_mask', src_mask)

encoder_output = active_encoder(
x=src,
mask=src_mask,
return_embeddings=True,
)

encoder_output, alphas = self.attention_bridge(encoder_output, src_mask)
if self.attention_bridge.is_fixed_length:
# turn off masking in the transformer decoder
lengths = None

if not bptt:
self.decoder.init_state(src, memory_bank, enc_state)
dec_out, attns = self.decoder(dec_in, memory_bank, memory_lengths=lengths, with_align=with_align)
return dec_out, attns
src_mask = None

retval = active_decoder(
decoder_input,
context=encoder_output,
context_mask=src_mask,
return_attn=return_attention,
return_logits_and_embeddings=True,
)
if return_attention:
(logits, decoder_output), attentions = retval
else:
logits, decoder_output = retval
attentions = None
return logits, decoder_output, attentions

def update_dropout(self, dropout):
self.encoder.update_dropout(dropout)
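
After this change the training loop calls the model with an explicit decoder input and source mask and receives logits directly, rather than decoder states that still need a separate generator. A hedged sketch of the new calling convention; the batch layout, the criterion, and any shifting of the target for teacher forcing are assumptions handled by the caller, not shown in this diff:

    # Hypothetical training step against the new forward signature.
    logits, decoder_output, attentions = model(
        src=batch.src.tensor,
        decoder_input=batch.tgt.tensor,
        src_mask=batch.src.mask,
        return_attention=False,
        metadata=metadata,   # must expose corpus_id and the adapter id lists
    )
    loss = criterion(logits.reshape(-1, logits.size(-1)), batch.labels.reshape(-1))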