State dict fixes
The adapter injection code was causing parameter duplication in the state dict.

Another issue: to normalize or not to normalize?
We compute a normalization factor based on either tokens or sentences, but we
never apply it. The effect can be compensated for via the learning rate, as
long as batches are of approximately the same size. However, excessively high
learning rates trigger gradient clipping, which is especially detrimental here
because each component is clipped individually.
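
For concreteness, here is a minimal sketch of what applying that normalization
could look like in the trainer. It reuses the names from the commented-out code
in trainer.py further down (self.norm_method, batch.tgt.mask, batch.batch_size);
treat it as an illustration, not as part of this commit.

    # Sketch: divide the loss by a per-batch normalization term before backprop.
    if self.norm_method == "tokens":
        normalization = batch.tgt.mask.sum().item()  # non-padded target tokens
    else:
        normalization = batch.batch_size  # number of sentences
    loss = loss / max(normalization, 1)  # guard against an empty batch
    self.optim.backward(loss)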

Deterministic clipping requires one of the following:
- access to the gradients of all parameters of the entire model (infeasible)
- component-local clipping (current approach)
- communicating a clipping factor across devices (maybe we should do
  this? see the sketch after this list)
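
A rough sketch of the third option, assuming each parameter group can be tied
to a torch.distributed process group spanning the devices that hold it, and
ignoring double counting of parameters replicated on several devices (which a
real implementation would have to handle). The function name and the group
plumbing are assumptions, not something this commit implements.

    import torch
    import torch.distributed as dist

    def clip_with_shared_factor(params, max_norm, group=None):
        # Local squared gradient norm over this device's share of the parameters.
        params = [p for p in params if p.grad is not None]
        total_sq = torch.zeros((), device=params[0].grad.device)
        for p in params:
            total_sq += p.grad.detach().pow(2).sum()
        # Sum squared norms across devices so every rank derives the same
        # clipping factor, making the clipping decision deterministic.
        if group is not None:
            dist.all_reduce(total_sq, op=dist.ReduceOp.SUM, group=group)
        total_norm = total_sq.sqrt()
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1.0:
            for p in params:
                p.grad.detach().mul_(clip_coef)
        return total_norm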
Waino committed Oct 14, 2024
1 parent 97bc2d9 commit 3a16fa0
Showing 6 changed files with 129 additions and 4 deletions.
20 changes: 19 additions & 1 deletion mammoth/distributed/components.py
@@ -121,6 +121,15 @@ def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, A
sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
return destination

def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
module = self.get_module(model)
mismatch = module.load_state_dict(state_dict, strict=False)
# Keys owned by other distributed components (the attention layers and
# the token embeddings) are expected to be missing here.
missing_keys = [
name for name in mismatch.missing_keys
if not (name.startswith('attn_layers.') or name.startswith('token_emb.'))
]
return mismatch._replace(missing_keys=missing_keys)


@dataclass # type: ignore
class DistributedAttentionLayersBlock(DistributedComponent, ABC):
@@ -147,13 +156,22 @@ def named_parameters(self, model: NMTModel):
def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
module = self.get_module(model)
destination: Dict[str, Any] = OrderedDict()
for name, sub_module in module._modules.items():
for name, sub_module in module.get_sub_modules().items():
if name == 'adapters':
# Adapters are stored separately
continue
sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
return destination

def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
module = self.get_module(model)
mismatch = module.load_state_dict(state_dict, strict=False)
# Keys under 'layers.' are skipped by this component's state_dict
# (see get_sub_modules in adapters.py), so they are expected to be missing.
missing_keys = [
name for name in mismatch.missing_keys
if not name.startswith('layers.')
]
return mismatch._replace(missing_keys=missing_keys)


@dataclass
class DistributedEncoderAttentionLayersBlock(DistributedAttentionLayersBlock):
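For illustration, a caller could consume the filtered mismatch report roughly
like this (a hypothetical sketch; component and checkpoint_state are
placeholder names, not identifiers from this commit):

    mismatch = component.load_state_dict(model, checkpoint_state)
    if mismatch.missing_keys or mismatch.unexpected_keys:
        raise RuntimeError(
            f'State dict mismatch when loading {component}: '
            f'missing {mismatch.missing_keys}, unexpected {mismatch.unexpected_keys}'
        )
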
7 changes: 7 additions & 0 deletions mammoth/modules/adapters.py
@@ -225,3 +225,10 @@ def _inject_adapters(self):
def forward(self, *args, **kwargs):
self._inject_adapters()
return super().forward(*args, **kwargs)

def get_sub_modules(self):
# 'layers' is rebuilt by _inject_adapters and would otherwise be saved
# twice (the parameter duplication mentioned in the commit message).
omit_submodules = {'layers'}
return {
name: sub_module for name, sub_module in self._modules.items()
if name not in omit_submodules
}
87 changes: 86 additions & 1 deletion mammoth/tests/test_task_queue_manager.py
@@ -5,9 +5,10 @@
from mammoth.distributed import TaskQueueManager, WorldContext
from mammoth.distributed.components import (
Side,
DistributedEncoderAttentionLayersBlock,
DistributedDecoderAttentionLayersBlock,
DistributedEmbedding,
DistributedEncoderAttentionLayersBlock,
DistributedTransformerWrapper,
# DistributedAdapter,
# DistributedAttentionBridge,
# DistributedComponentAction,
@@ -169,6 +170,34 @@ def __call__(self, sorted_global_ranks):
layer_stack_index=0,
xcoder_id="yy",
),
DistributedTransformerWrapper(
global_ranks={0},
task_ids={'train_0_a-b'},
group=None,
task_id='train_0_a-b',
side=Side.decoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_1_c-d'},
group=None,
task_id='train_1_c-d',
side=Side.decoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_2_a-d'},
group=None,
task_id='train_2_a-d',
side=Side.decoder,
),
DistributedTransformerWrapper(
global_ranks={2},
task_ids={'train_3_e-b'},
group=None,
task_id='train_3_e-b',
side=Side.decoder,
),
DistributedEncoderAttentionLayersBlock(
global_ranks={0, 1},
task_ids={"train_2_a-d", "train_0_a-b"},
@@ -190,6 +219,34 @@ def __call__(self, sorted_global_ranks):
layer_stack_index=0,
xcoder_id="xxx",
),
DistributedTransformerWrapper(
global_ranks={0},
task_ids={'train_0_a-b'},
group=None,
task_id='train_0_a-b',
side=Side.encoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_1_c-d'},
group=None,
task_id='train_1_c-d',
side=Side.encoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_2_a-d'},
group=None,
task_id='train_2_a-d',
side=Side.encoder,
),
DistributedTransformerWrapper(
global_ranks={2},
task_ids={'train_3_e-b'},
group=None,
task_id='train_3_e-b',
side=Side.encoder,
),
DistributedEmbedding(
global_ranks={0, 1},
task_ids={"train_0_a-b", "train_2_a-d"},
@@ -251,6 +308,20 @@ def __call__(self, sorted_global_ranks):
layer_stack_index=0,
xcoder_id="yy",
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_1_c-d'},
group=None,
task_id='train_1_c-d',
side=Side.decoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_2_a-d'},
group=None,
task_id='train_2_a-d',
side=Side.decoder,
),
DistributedEncoderAttentionLayersBlock(
global_ranks={0, 1},
task_ids={"train_2_a-d", "train_0_a-b"},
@@ -265,6 +336,20 @@ def __call__(self, sorted_global_ranks):
layer_stack_index=0,
xcoder_id="xx",
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_1_c-d'},
group=None,
task_id='train_1_c-d',
side=Side.encoder,
),
DistributedTransformerWrapper(
global_ranks={1},
task_ids={'train_2_a-d'},
group=None,
task_id='train_2_a-d',
side=Side.encoder,
),
DistributedEmbedding(
global_ranks={0, 1},
task_ids={"train_0_a-b", "train_2_a-d"},
6 changes: 6 additions & 0 deletions mammoth/trainer.py
@@ -393,6 +393,10 @@ def validate(self, valid_iter, moving_average=None, task=None):
src_mask = batch.src.mask
decoder_input = batch.tgt.tensor[:-1]
target = batch.tgt.tensor[1:]
# if self.norm_method == "tokens":
# normalization = batch.tgt.mask.sum().item()
# else:
# normalization = batch.batch_size

with torch.cuda.amp.autocast(enabled=self.optim.amp):
# F-prop through the model.
@@ -410,6 +414,7 @@ def validate(self, valid_iter, moving_average=None, task=None):
rearrange(logits, 't b i -> (t b) i'),
rearrange(target, 't b 1 -> (t b)'),
)
# loss /= normalization

# Update statistics.
padding_idx = self.loss_functions[metadata.tgt_lang].ignore_index
@@ -490,6 +495,7 @@ def _gradient_accumulation(
if loss is not None:
if torch.isnan(loss):
raise Exception('Loss blowout')
# loss /= normalization
self.optim.backward(loss)

if self.report_training_accuracy:
11 changes: 10 additions & 1 deletion mammoth/utils/optimizers.py
@@ -158,6 +158,7 @@ def __init__(
self.grad_scaler = grad_scaler
self._training_step = 1
self._decay_step = 1
self._n_clips = 0

@property
def param_groups(self):
@@ -194,6 +195,7 @@ def load_state_dict(self, state_dict):
def zero_grad(self):
"""Zero the gradients of optimized parameters."""
self._optimizer.zero_grad()
self._n_clips = 0

def step(self):
"""Update the model parameters based on current gradients. """
@@ -202,7 +204,11 @@ def step(self):
for group in self._optimizer.param_groups:
group['lr'] = learning_rate
if self._max_grad_norm > 0:
clip_grad_norm_(group['params'], self._max_grad_norm)
orig_norm = clip_grad_norm_(group['params'], self._max_grad_norm)
if orig_norm.item() > self._max_grad_norm:
# FIXME: debug. Count and log instead
print(f'Clipping {orig_norm} -> {self._max_grad_norm}')
self._n_clips += 1

if self.grad_scaler is not None:
self.grad_scaler.step(self._optimizer)
@@ -396,6 +402,9 @@ def amp(self):
"""True if use torch amp mix precision training."""
return self.grad_scaler is not None

def n_clips(self):
return sum(optimizer._n_clips for optimizer in self.suboptimizers.values())


# Code below is an implementation of https://arxiv.org/pdf/1804.04235.pdf
# inspired but modified from https://github.com/DeadAt0m/adafactor-pytorch
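As a usage note, the new clip counter could be surfaced in training logs
roughly as follows (a sketch with assumed names; the commit itself only prints
from inside the optimizer and flags that as a FIXME):

    # Sketch: report how many parameter groups were clipped since the last
    # zero_grad(), e.g. once per reporting interval instead of every step.
    n_clips = self.optim.n_clips()
    if n_clips > 0:
        logger.info('Gradient clipping fired for %d parameter groups', n_clips)
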
2 changes: 1 addition & 1 deletion tools/generate_synth_data.py
@@ -355,7 +355,7 @@ def generate_from_config(
@click.command(context_settings={'show_default': True})
@click.option('--config_path', type=Path, required=True)
@click.option('--vocab_size', type=int, default=300)
@click.option('--num_examples_train', type=int, default=10000)
@click.option('--num_examples_train', type=int, default=1000000)
@click.option('--num_examples_test', type=int, default=100)
@click.option('--start_seed', type=int, default=1)
@click.option(
