
Commit

merge
Joseph Attieh authored and Joseph Attieh committed Oct 14, 2024
2 parents bc8735a + 97bc2d9 commit 16b4e6d
Showing 9 changed files with 140 additions and 44 deletions.
45 changes: 40 additions & 5 deletions mammoth/distributed/components.py
@@ -87,7 +87,41 @@ def needs_communication(self) -> bool:
return self.group is not None


# TODO: This is a misnomer: Not an entire XCoder, but just one AttentionLayers block
@dataclass # type: ignore
class DistributedTransformerWrapper(DistributedComponent, ABC):
task_id: str
side: Side

def get_name(self) -> str:
return f'{self.side.name}_{self.task_id}'

def get_module(self, model: NMTModel) -> nn.Module:
parent = model.encoder if self.side == Side.encoder else model.decoder
tw = parent[self.task_id]
return tw

def named_parameters(self, model: NMTModel):
module = self.get_module(model)
for name, p in module.named_parameters():
# TransformerWrapper contains the AttentionLayers and the embs.
# however, we want to treat these as distinct DistributedComponents
if name.startswith('attn_layers.'):
continue
if name.startswith('token_emb.'):
continue
yield name, p

def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
module = self.get_module(model)
destination: Dict[str, Any] = OrderedDict()
for name, sub_module in module._modules.items():
if name.endswith('attn_layers'):
# stored separately
continue
sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
return destination


@dataclass # type: ignore
class DistributedAttentionLayersBlock(DistributedComponent, ABC):
layer_stack_index: int
@@ -106,8 +140,9 @@ def named_parameters(self, model: NMTModel):
for name, p in module.named_parameters():
# encoders and decoders contain embeddings and adapters as submodules
# however, we want to treat these as distinct DistributedComponents
if 'embeddings' not in name and 'adapter' not in name:
yield name, p
if 'adapter' in name:
continue
yield name, p

def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
module = self.get_module(model)
@@ -121,7 +156,7 @@ def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:


@dataclass
class DistributedEncoder(DistributedAttentionLayersBlock):
class DistributedEncoderAttentionLayersBlock(DistributedAttentionLayersBlock):
@property
def side(self) -> Side:
return Side.encoder
@@ -136,7 +171,7 @@ def get_module(self, model: NMTModel) -> nn.Module:


@dataclass
class DistributedDecoder(DistributedAttentionLayersBlock):
class DistributedDecoderAttentionLayersBlock(DistributedAttentionLayersBlock):
@property
def side(self) -> Side:
return Side.decoder
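For illustration, a minimal standalone sketch of the prefix filtering performed by DistributedTransformerWrapper.named_parameters: only parameters owned directly by the TransformerWrapper are yielded, while anything under attn_layers or token_emb is left to the corresponding DistributedAttentionLayersBlock and DistributedEmbedding components. The ToyWrapper below is a hypothetical stand-in for x_transformers.TransformerWrapper, not code from this commit.

from torch import nn

class ToyWrapper(nn.Module):
    # Hypothetical stand-in: owns an attn_layers block, a token embedding,
    # and one module of its own.
    def __init__(self, dim=8, vocab=100):
        super().__init__()
        self.attn_layers = nn.Linear(dim, dim)     # synced via DistributedAttentionLayersBlock
        self.token_emb = nn.Embedding(vocab, dim)  # synced via DistributedEmbedding
        self.to_logits = nn.Linear(dim, vocab)     # stays with the TransformerWrapper component

def wrapper_own_parameters(module: nn.Module):
    # Same filtering rule as DistributedTransformerWrapper.named_parameters above.
    for name, p in module.named_parameters():
        if name.startswith('attn_layers.') or name.startswith('token_emb.'):
            continue
        yield name, p

print([name for name, _ in wrapper_own_parameters(ToyWrapper())])
# -> ['to_logits.weight', 'to_logits.bias']
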
27 changes: 23 additions & 4 deletions mammoth/distributed/tasks.py
@@ -17,9 +17,10 @@
DistributedComponent,
DistributedComponentBuilder,
DistributedComponentGradientSync,
DistributedDecoder,
DistributedDecoderAttentionLayersBlock,
DistributedEmbedding,
DistributedEncoder,
DistributedEncoderAttentionLayersBlock,
DistributedTransformerWrapper,
Side,
)
from mammoth.distributed.contexts import DeviceContext, WorldContext
@@ -369,9 +370,27 @@ def create_all_distributed_components(
lang=task.tgt_lang,
)
)
builder.add(
DistributedTransformerWrapper(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
side=Side.encoder,
task_id=task.corpus_id,
)
)
builder.add(
DistributedTransformerWrapper(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
side=Side.decoder,
task_id=task.corpus_id,
)
)
for layer_stack_index, encoder_id in enumerate(task.encoder_id):
builder.add(
DistributedEncoder(
DistributedEncoderAttentionLayersBlock(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
@@ -381,7 +400,7 @@
)
for layer_stack_index, decoder_id in enumerate(task.decoder_id):
builder.add(
DistributedDecoder(
DistributedDecoderAttentionLayersBlock(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
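As a sketch of what the two new builder.add(...) calls register for each task: one DistributedTransformerWrapper per side, named after the side and the corpus id via get_name defined in components.py above. The construction mirrors the call in create_all_distributed_components; the corpus id is invented, and the expected name assumes Side is an enum whose members are named encoder and decoder.

from mammoth.distributed.components import DistributedTransformerWrapper, Side

component = DistributedTransformerWrapper(
    global_ranks={0},
    task_ids={'train_0_a-b'},   # invented corpus id, for illustration only
    group=None,                 # single-rank component: no communication group
    side=Side.encoder,
    task_id='train_0_a-b',
)
print(component.get_name())  # expected: 'encoder_train_0_a-b'
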
64 changes: 50 additions & 14 deletions mammoth/model_builder.py
@@ -14,8 +14,8 @@
from mammoth.distributed.components import (
DistributedAdapter,
DistributedComponent,
DistributedDecoder,
DistributedEncoder,
DistributedDecoderAttentionLayersBlock,
DistributedEncoderAttentionLayersBlock,
Side,
)
from mammoth.modules.adapters import (
@@ -44,6 +44,14 @@ def forward(self, x):
token_emb = self.emb(x.long())
return token_emb

TRANSFORMER_WRAPPER_OPTS = {
'post_emb_norm',
'tie_embedding',
'use_abs_pos_emb',
'scaled_sinu_pos_emb',
'emb_frac_gradient',
}


def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:
result = []
@@ -72,6 +80,7 @@ def get_attention_layers_kwargs(
is_last = layer_stack_index == len(depths) - 1
pre_norm_has_final_norm = is_last
kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
kwargs = {key: val for key, val in kwargs.items() if key not in TRANSFORMER_WRAPPER_OPTS}
kwargs.update({
'dim': model_opts.model_dim,
'depth': depth,
Expand All @@ -82,6 +91,21 @@ def get_attention_layers_kwargs(
return kwargs


def get_transformer_wrapper_kwargs(
side: Side,
model_opts,
):
"""Return arguments for x_transformers.TransformerWrapper"""
assert side in {Side.encoder, Side.decoder}, f'Invalid side "{side}"'
kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
kwargs = {key: val for key, val in kwargs.items() if key in TRANSFORMER_WRAPPER_OPTS}
max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
kwargs.update({
'max_seq_len': max_seq_len,
})
return kwargs
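# Illustrative sketch (not from this commit): together with get_attention_layers_kwargs
# above, this helper partitions a single x_transformers_opts dict between the keys consumed
# by TransformerWrapper and those consumed by AttentionLayers. The option names below are
# only examples of plausible x_transformers options.
example_opts = {'use_abs_pos_emb': True, 'tie_embedding': True, 'attn_flash': True, 'ff_glu': True}
wrapper_opts = {k: v for k, v in example_opts.items() if k in TRANSFORMER_WRAPPER_OPTS}
attn_opts = {k: v for k, v in example_opts.items() if k not in TRANSFORMER_WRAPPER_OPTS}
assert wrapper_opts == {'use_abs_pos_emb': True, 'tie_embedding': True}
assert attn_opts == {'attn_flash': True, 'ff_glu': True}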


def build_xcoder(
side: Side,
model_opts,
@@ -108,10 +132,10 @@
]
distributed_xcoder_class: type
if side == Side.encoder:
distributed_xcoder_class = DistributedEncoder
distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
side_str = 'encoder'
else:
distributed_xcoder_class = DistributedDecoder
distributed_xcoder_class = DistributedDecoderAttentionLayersBlock
side_str = 'decoder'
if single_task:
my_components = [
@@ -210,6 +234,10 @@ def build_xcoder(
if single_task:
tasks = [task for task in tasks if task.corpus_id == single_task]
transformer_wrappers = dict()
transformer_wrapper_kwargs = get_transformer_wrapper_kwargs(
side=side,
model_opts=model_opts,
)
for task in tasks:
if side == Side.encoder:
xcoder_ids = task.encoder_id
@@ -225,21 +253,11 @@

lang = task.src_lang if side == Side.encoder else task.tgt_lang
vocab = vocabs_dict[(side_alt_str, lang)]
max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
post_emb_norm = True
tie_embedding = True
use_abs_pos_emb = True
emb_frac_gradient = 1.
# Using custom extended TransformerWrapper to allow passing in an embedding
transformer_wrapper = TransformerWrapper(
num_tokens=len(vocab),
max_seq_len=max_seq_len,
attn_layers=adapted_attention_layers_stack,
emb_dim=model_opts.model_dim,
post_emb_norm=post_emb_norm,
tie_embedding=tie_embedding,
use_abs_pos_emb=use_abs_pos_emb,
emb_frac_gradient=emb_frac_gradient,
token_emb=token_embs[lang],
initialize_embeddings=not (model_opts.use_embeddingless)
)
@@ -324,3 +342,21 @@ def build_model(
# logger.info(model)
logger.info('Building model - done!')
return model


def validate_optimizer_coverage(model, optimizer):
trainable_model_params = {
name: p for name, p in model.named_parameters()
if p.requires_grad
}
optimized_params = set()
for group in optimizer.param_groups:
optimized_params.update(group['params'])
missing_params = [
name for name, p in trainable_model_params.items()
if p not in optimized_params
]
if len(missing_params) > 0:
raise Exception(f'Missing optimizer for params: {sorted(missing_params)}')
else:
logger.info('All non-frozen parameters have an optimizer')
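
A minimal sketch of how the new validate_optimizer_coverage check behaves, using a toy torch model and plain torch optimizers in place of the Mammoth model and its MultipleOptimizer. The import path matches the one used in train_single.py below; everything else is illustrative.

import torch
from torch import nn
from mammoth.model_builder import validate_optimizer_coverage

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].weight.requires_grad = False  # frozen parameters are ignored by the check

# Covers every trainable parameter -> logs that all non-frozen parameters have an optimizer.
good_optim = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
validate_optimizer_coverage(model, good_optim)

# Leaves the second Linear out of the param groups -> raises, listing the missing names.
bad_optim = torch.optim.SGD(model[0].parameters(), lr=0.1)
try:
    validate_optimizer_coverage(model, bad_optim)
except Exception as exc:
    print(exc)  # Missing optimizer for params: ['1.bias', '1.weight']
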
20 changes: 10 additions & 10 deletions mammoth/tests/test_task_queue_manager.py
@@ -5,8 +5,8 @@
from mammoth.distributed import TaskQueueManager, WorldContext
from mammoth.distributed.components import (
Side,
DistributedEncoder,
DistributedDecoder,
DistributedEncoderAttentionLayersBlock,
DistributedDecoderAttentionLayersBlock,
DistributedEmbedding,
# DistributedAdapter,
# DistributedAttentionBridge,
@@ -155,35 +155,35 @@ def __call__(self, sorted_global_ranks):
use_attention_bridge=False, new_group_func=MockGroup()
)
assert all_components == [
DistributedDecoder(
DistributedDecoderAttentionLayersBlock(
global_ranks={0, 2},
task_ids={'train_3_e-b', 'train_0_a-b'},
group="Group 0 with GPU ranks [0, 2]",
layer_stack_index=0,
xcoder_id="y",
),
DistributedDecoder(
DistributedDecoderAttentionLayersBlock(
global_ranks={1},
task_ids={"train_2_a-d", "train_1_c-d"},
group=None,
layer_stack_index=0,
xcoder_id="yy",
),
DistributedEncoder(
DistributedEncoderAttentionLayersBlock(
global_ranks={0, 1},
task_ids={"train_2_a-d", "train_0_a-b"},
group="Group 1 with GPU ranks [0, 1]",
layer_stack_index=0,
xcoder_id="x",
),
DistributedEncoder(
DistributedEncoderAttentionLayersBlock(
global_ranks={1},
task_ids={"train_1_c-d"},
group=None,
layer_stack_index=0,
xcoder_id="xx",
),
DistributedEncoder(
DistributedEncoderAttentionLayersBlock(
global_ranks={2},
task_ids={'train_3_e-b'},
group=None,
@@ -244,21 +244,21 @@ def __call__(self, sorted_global_ranks):
f"my component {component} not in all_components {all_components}"
)
assert my_components == [
DistributedDecoder(
DistributedDecoderAttentionLayersBlock(
global_ranks={1},
task_ids={"train_2_a-d", "train_1_c-d"},
group=None,
layer_stack_index=0,
xcoder_id="yy",
),
DistributedEncoder(
DistributedEncoderAttentionLayersBlock(
global_ranks={0, 1},
task_ids={"train_2_a-d", "train_0_a-b"},
group="Group 1 with GPU ranks [0, 1]",
layer_stack_index=0,
xcoder_id="x",
),
DistributedEncoder(
DistributedEncoderAttentionLayersBlock(
global_ranks={1},
task_ids={"train_1_c-d"},
group=None,
3 changes: 2 additions & 1 deletion mammoth/train_single.py
@@ -3,7 +3,7 @@
import torch
import time

from mammoth.model_builder import build_model
from mammoth.model_builder import build_model, validate_optimizer_coverage
from mammoth.utils.optimizers import MultipleOptimizer
from mammoth.utils.misc import set_random_seed
from mammoth.trainer import build_trainer
@@ -139,6 +139,7 @@ def main(
device_context.id,
optim.count_parameters()
))
validate_optimizer_coverage(model, optim)

# Load parameters from checkpoint
if opts.train_from:
13 changes: 7 additions & 6 deletions mammoth/trainer.py
@@ -239,8 +239,9 @@ def train(
else:
logger.info('Start training loop and validate every %d steps...', valid_steps)

total_stats = mammoth.utils.Statistics()
report_stats = mammoth.utils.Statistics()
n_correct = 0 if self.report_training_accuracy else None
total_stats = mammoth.utils.Statistics(n_correct=n_correct)
report_stats = mammoth.utils.Statistics(n_correct=n_correct)
self._start_report_manager(start_time=total_stats.start_time)
self.optim.zero_grad()

@@ -385,7 +386,7 @@ def validate(self, valid_iter, moving_average=None, task=None):

for batch, metadata, _ in valid_iter:
if stats is None:
stats = mammoth.utils.Statistics()
stats = mammoth.utils.Statistics(n_correct=0)

stats.n_src_words += batch.src.mask.sum().item()
src = batch.src.tensor
@@ -470,9 +471,9 @@ def _gradient_accumulation(

with torch.cuda.amp.autocast(enabled=self.optim.amp):
logits, decoder_output = self.model(
rearrange(src, 't b 1 -> b t'),
rearrange(decoder_input, 't b 1 -> b t'),
rearrange(src_mask, 't b -> b t'),
src=rearrange(src, 't b 1 -> b t'),
decoder_input=rearrange(decoder_input, 't b 1 -> b t'),
src_mask=rearrange(src_mask, 't b -> b t'),
metadata=metadata,
)
logits = rearrange(logits, 'b t i -> t b i')
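For reference, a small example of the einops patterns that the forward call above now receives as keyword arguments: 't b 1 -> b t' drops the trailing singleton feature dimension and makes the batch dimension leading, while 'b t i -> t b i' converts the logits back to time-major for the loss. The tensor shapes and contents below are arbitrary.

import torch
from einops import rearrange

src = torch.arange(10).reshape(5, 2, 1)            # time-major (t=5, b=2, 1)
print(rearrange(src, 't b 1 -> b t').shape)        # torch.Size([2, 5])

logits = torch.zeros(2, 5, 7)                      # (b, t, vocab); vocab size 7 is arbitrary
print(rearrange(logits, 'b t i -> t b i').shape)   # torch.Size([5, 2, 7])
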
2 changes: 1 addition & 1 deletion mammoth/utils/model_saver.py
@@ -112,7 +112,7 @@ def load_parameters_from_checkpoint(
logger.warning(
f'Could not find optim checkpoint file {optimizer_path}. Affected parameters are reinitialized.'
)
all_ok = False
all_ok = True
if all_ok:
if reset_optim:
logger.info(f'All modules restored from checkpoint {checkpoint_prefix}')
