Merge branch 'feat/model_blowout' into feat/embeddingless
Joseph Attieh authored and committed on Oct 17, 2024
2 parents (1ed939f + 3a16fa0), commit 839b06c
Showing 6 changed files with 129 additions and 4 deletions.
mammoth/distributed/components.py (19 additions, 1 deletion)
@@ -121,6 +121,15 @@ def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, A
            sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
        return destination

    def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
        module = self.get_module(model)
        mismatch = module.load_state_dict(state_dict, strict=False)
        missing_keys = [
            name for name in mismatch.missing_keys
            # keys owned by other distributed components are expected to be absent here
            if not (name.startswith('attn_layers.') or name.startswith('token_emb.'))
        ]
        return mismatch._replace(missing_keys=missing_keys)


@dataclass # type: ignore
class DistributedAttentionLayersBlock(DistributedComponent, ABC):
@@ -147,13 +156,22 @@ def named_parameters(self, model: NMTModel):
    def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
        module = self.get_module(model)
        destination: Dict[str, Any] = OrderedDict()
-       for name, sub_module in module._modules.items():
+       for name, sub_module in module.get_sub_modules().items():
            if name == 'adapters':
                # Adapters are stored separately
                continue
            sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
        return destination

    def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
        module = self.get_module(model)
        mismatch = module.load_state_dict(state_dict, strict=False)
        missing_keys = [
            name for name in mismatch.missing_keys
            if not name.startswith('layers.')
        ]
        return mismatch._replace(missing_keys=missing_keys)


@dataclass
class DistributedEncoderAttentionLayersBlock(DistributedAttentionLayersBlock):
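
The two load_state_dict overrides above lean on torch.nn.Module.load_state_dict(strict=False), which returns an IncompatibleKeys namedtuple; keys owned by other distributed components are then dropped from the missing-key report. A minimal self-contained sketch of that idea (ToyWrapper and the partial checkpoint are illustrative stand-ins, not part of this commit):

from torch import nn

class ToyWrapper(nn.Module):
    """Stand-in for a TransformerWrapper whose attn_layers / token_emb weights
    are checkpointed by other distributed components."""
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(10, 4)
        self.attn_layers = nn.Linear(4, 4)
        self.to_logits = nn.Linear(4, 10)

wrapper = ToyWrapper()
# The checkpoint for this component only holds the wrapper-specific weights.
partial_state = {k: v for k, v in wrapper.state_dict().items() if k.startswith('to_logits.')}

mismatch = wrapper.load_state_dict(partial_state, strict=False)
# attn_layers.* and token_emb.* are expected to be absent, so they are filtered out.
missing_keys = [
    name for name in mismatch.missing_keys
    if not (name.startswith('attn_layers.') or name.startswith('token_emb.'))
]
print(mismatch._replace(missing_keys=missing_keys))  # missing_keys is now empty
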
mammoth/modules/adapters.py (7 additions)
@@ -225,3 +225,10 @@ def _inject_adapters(self):
    def forward(self, *args, **kwargs):
        self._inject_adapters()
        return super().forward(*args, **kwargs)

    def get_sub_modules(self):
        omit_submodules = {'layers'}
        return {
            name: sub_module for name, sub_module in self._modules.items()
            if name not in omit_submodules
        }
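
A short sketch of how the new get_sub_modules() hook combines with the component state_dict above: submodules named in omit_submodules ('layers') and the separately stored 'adapters' never end up in this component's checkpoint. The toy module below is illustrative, not mammoth's AdaptedAttentionLayers:

from collections import OrderedDict
from torch import nn

class ToyAttentionLayers(nn.Module):
    def __init__(self):
        super().__init__()
        self.final_norm = nn.LayerNorm(4)
        self.layers = nn.ModuleList([nn.Linear(4, 4)])          # omitted via get_sub_modules()
        self.adapters = nn.ModuleDict({'a': nn.Linear(4, 4)})   # stored by a separate component

    def get_sub_modules(self):
        omit_submodules = {'layers'}
        return {
            name: sub_module for name, sub_module in self._modules.items()
            if name not in omit_submodules
        }

module = ToyAttentionLayers()
destination = OrderedDict()
for name, sub_module in module.get_sub_modules().items():
    if name == 'adapters':
        continue  # adapters are stored separately
    sub_module.state_dict(destination=destination, prefix=name + '.', keep_vars=False)
print(list(destination))  # ['final_norm.weight', 'final_norm.bias']
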
mammoth/tests/test_task_queue_manager.py (86 additions, 1 deletion)
@@ -5,9 +5,10 @@
from mammoth.distributed import TaskQueueManager, WorldContext
from mammoth.distributed.components import (
    Side,
-   DistributedEncoderAttentionLayersBlock,
    DistributedDecoderAttentionLayersBlock,
    DistributedEmbedding,
+   DistributedEncoderAttentionLayersBlock,
+   DistributedTransformerWrapper,
    # DistributedAdapter,
    # DistributedAttentionBridge,
    # DistributedComponentAction,
@@ -169,6 +170,34 @@ def __call__(self, sorted_global_ranks):
layer_stack_index=0,
xcoder_id="yy",
),
DistributedTransformerWrapper(
global_ranks={0},
task_ids={'train_0_a-b'},
group=None,
task_id='train_0_a-b',
side=Side.decoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_1_c-d'},
group=None,
task_id='train_1_c-d',
side=Side.decoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_2_a-d'},
group=None,
task_id='train_2_a-d',
side=Side.decoder,
),
DistributedTransformerWrapper(
global_ranks={2},
task_ids={'train_3_e-b'},
group=None,
task_id='train_3_e-b',
side=Side.decoder,
),
DistributedEncoderAttentionLayersBlock(
global_ranks={0, 1},
task_ids={"train_2_a-d", "train_0_a-b"},
@@ -190,6 +219,34 @@ def __call__(self, sorted_global_ranks):
layer_stack_index=0,
xcoder_id="xxx",
),
DistributedTransformerWrapper(
global_ranks={0},
task_ids={'train_0_a-b'},
group=None,
task_id='train_0_a-b',
side=Side.encoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_1_c-d'},
group=None,
task_id='train_1_c-d',
side=Side.encoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_2_a-d'},
group=None,
task_id='train_2_a-d',
side=Side.encoder,
),
DistributedTransformerWrapper(
global_ranks={2},
task_ids={'train_3_e-b'},
group=None,
task_id='train_3_e-b',
side=Side.encoder,
),
DistributedEmbedding(
global_ranks={0, 1},
task_ids={"train_0_a-b", "train_2_a-d"},
@@ -251,6 +308,20 @@ def __call__(self, sorted_global_ranks):
layer_stack_index=0,
xcoder_id="yy",
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_1_c-d'},
group=None,
task_id='train_1_c-d',
side=Side.decoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_2_a-d'},
group=None,
task_id='train_2_a-d',
side=Side.decoder,
),
DistributedEncoderAttentionLayersBlock(
global_ranks={0, 1},
task_ids={"train_2_a-d", "train_0_a-b"},
@@ -265,6 +336,20 @@ def __call__(self, sorted_global_ranks):
layer_stack_index=0,
xcoder_id="xx",
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_1_c-d'},
group=None,
task_id='train_1_c-d',
side=Side.encoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_2_a-d'},
group=None,
task_id='train_2_a-d',
side=Side.encoder,
),
DistributedEmbedding(
global_ranks={0, 1},
task_ids={"train_0_a-b", "train_2_a-d"},
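
The bulk of the new test expectations is mechanical: every task now contributes one DistributedTransformerWrapper per side, living only on the rank that runs the task. A hedged sketch of that shape, using plain dicts instead of the real dataclass and the task-to-rank mapping read off the entries above:

task_to_rank = {'train_0_a-b': 0, 'train_1_c-d': 1, 'train_2_a-d': 1, 'train_3_e-b': 2}

expected_wrappers = [
    dict(global_ranks={rank}, task_ids={task}, group=None, task_id=task, side=side)
    for side in ('decoder', 'encoder')
    for task, rank in task_to_rank.items()
]
assert len(expected_wrappers) == 8  # 4 tasks x 2 sides, matching the entries above
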
mammoth/trainer.py (6 additions)
@@ -393,6 +393,10 @@ def validate(self, valid_iter, moving_average=None, task=None):
src_mask = batch.src.mask
decoder_input = batch.tgt.tensor[:-1]
target = batch.tgt.tensor[1:]
# if self.norm_method == "tokens":
# normalization = batch.tgt.mask.sum().item()
# else:
# normalization = batch.batch_size

with torch.cuda.amp.autocast(enabled=self.optim.amp):
# F-prop through the model.
@@ -410,6 +414,7 @@
rearrange(logits, 't b i -> (t b) i'),
rearrange(target, 't b 1 -> (t b)'),
)
# loss /= normalization

# Update statistics.
padding_idx = self.loss_functions[metadata.tgt_lang].ignore_index
@@ -490,6 +495,7 @@ def _gradient_accumulation(
if loss is not None:
if torch.isnan(loss):
raise Exception('Loss blowout')
# loss /= normalization
self.optim.backward(loss)

if self.report_training_accuracy:
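
The commented-out lines in validate() and _gradient_accumulation() sketch a loss normalization that is not enabled in this commit: divide the summed loss by the number of real target tokens (norm_method == "tokens") or by the batch size. A standalone sketch of that idea, assuming the target mask is a boolean tensor over non-padding tokens (the helper name and exact mask semantics are assumptions, not mammoth API):

import torch

def normalize_loss(loss: torch.Tensor, tgt_mask: torch.Tensor, batch_size: int,
                   norm_method: str = "tokens") -> torch.Tensor:
    """Divide a summed loss by the token count or by the batch size."""
    if norm_method == "tokens":
        normalization = int(tgt_mask.sum().item())
    else:
        normalization = batch_size
    return loss / max(normalization, 1)  # guard against an empty batch

loss = torch.tensor(12.0)
tgt_mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.bool)  # 5 non-padding tokens
print(normalize_loss(loss, tgt_mask, batch_size=2))  # tensor(2.4000)
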
mammoth/utils/optimizers.py (10 additions, 1 deletion)
@@ -158,6 +158,7 @@ def __init__(
        self.grad_scaler = grad_scaler
        self._training_step = 1
        self._decay_step = 1
        self._n_clips = 0

    @property
    def param_groups(self):
@@ -194,6 +195,7 @@ def load_state_dict(self, state_dict):
    def zero_grad(self):
        """Zero the gradients of optimized parameters."""
        self._optimizer.zero_grad()
        self._n_clips = 0

    def step(self):
        """Update the model parameters based on current gradients. """
@@ -202,7 +204,11 @@ def step(self):
        for group in self._optimizer.param_groups:
            group['lr'] = learning_rate
            if self._max_grad_norm > 0:
-               clip_grad_norm_(group['params'], self._max_grad_norm)
+               orig_norm = clip_grad_norm_(group['params'], self._max_grad_norm)
+               if orig_norm.item() > self._max_grad_norm:
+                   # FIXME: debug. Count and log instead
+                   print(f'Clipping {orig_norm} -> {self._max_grad_norm}')
+                   self._n_clips += 1

        if self.grad_scaler is not None:
            self.grad_scaler.step(self._optimizer)
@@ -396,6 +402,9 @@ def amp(self):
"""True if use torch amp mix precision training."""
return self.grad_scaler is not None

def n_clips(self):
return sum(optimizer._n_clips for optimizer in self.suboptimizers.values())


# Code below is an implementation of https://arxiv.org/pdf/1804.04235.pdf
# inspired but modified from https://github.com/DeadAt0m/adafactor-pytorch
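
clip_grad_norm_ returns the total gradient norm measured before clipping, which is what makes the clip counter above possible; the FIXME suggests counting and logging rather than printing. A small self-contained sketch of that pattern with a plain torch optimizer (not mammoth's MultipleOptimizer):

import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
max_grad_norm = 1.0
n_clips = 0

for step in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).sum()
    loss.backward()
    orig_norm = clip_grad_norm_(model.parameters(), max_grad_norm)
    if orig_norm.item() > max_grad_norm:
        n_clips += 1  # count instead of printing, as the FIXME suggests
    optimizer.step()

print(f'clipped {n_clips} of 3 steps')
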
tools/generate_synth_data.py (1 addition, 1 deletion)
@@ -355,7 +355,7 @@ def generate_from_config(
@click.command(context_settings={'show_default': True})
@click.option('--config_path', type=Path, required=True)
@click.option('--vocab_size', type=int, default=300)
-@click.option('--num_examples_train', type=int, default=10000)
+@click.option('--num_examples_train', type=int, default=1000000)
@click.option('--num_examples_test', type=int, default=100)
@click.option('--start_seed', type=int, default=1)
@click.option(
