
Commit

merged
Joseph Attieh authored and Joseph Attieh committed Oct 7, 2024
2 parents 70bc59d + 7229141 commit bc8735a
Showing 8 changed files with 56 additions and 113 deletions.
1 change: 0 additions & 1 deletion mammoth/model_builder.py
@@ -75,7 +75,6 @@ def get_attention_layers_kwargs(
kwargs.update({
'dim': model_opts.model_dim,
'depth': depth,
'heads': model_opts.heads,
'causal': causal,
'cross_attend': cross_attend,
'pre_norm_has_final_norm': pre_norm_has_final_norm,
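With 'heads' dropped from the builder kwargs, the head count is expected to arrive through --x_transformers_opts instead (see the matching change in mammoth/utils/parse.py below). A minimal sketch of that kwarg merging, assuming the standard x-transformers AttentionLayers API; the dict contents and variable names are illustrative, not the actual get_attention_layers_kwargs output:

    from x_transformers.x_transformers import AttentionLayers

    # Keys the builder still controls; these are rejected if they also appear
    # in --x_transformers_opts (see validate_x_transformers_opts below).
    builder_kwargs = {'dim': 512, 'depth': 6, 'causal': True, 'cross_attend': False}
    # User-supplied options, now including the head count.
    user_opts = {'heads': 8}

    attention_layers = AttentionLayers(**{**user_opts, **builder_kwargs})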
29 changes: 5 additions & 24 deletions mammoth/modules/transformer_decoder.py
@@ -60,6 +60,8 @@ def __init__(
"""
super(TransformerDecoderLayerBase, self).__init__()
assert not full_context_alignment, 'alignment is obsolete'
assert alignment_heads == 0, 'alignment is obsolete'

if self_attn_type == "scaled-dot":
self.self_attn = MultiHeadedAttention(
@@ -86,8 +88,6 @@ def __init__(
self.layer_norm_3 = nn.LayerNorm(d_model, eps=1e-6)
self.layer_norm_4 = nn.LayerNorm(d_model, eps=1e-6)
self.drop = nn.Dropout(dropout)
self.full_context_alignment = full_context_alignment
self.alignment_heads = alignment_heads

def forward(self, *args, **kwargs):
"""Extend `_forward` for (possibly) multiple decoder pass:
@@ -97,7 +97,6 @@ def forward(self, *args, **kwargs):
Args:
* All arguments of _forward.
with_align (bool): whether return alignment attention.
Returns:
(FloatTensor, FloatTensor, FloatTensor or None):
@@ -106,22 +105,9 @@
* top_attn ``(batch_size, T, src_len)``
* attn_align ``(batch_size, T, src_len)`` or None
"""
with_align = kwargs.pop("with_align", False)
output, attns = self._forward(*args, **kwargs)
top_attn = attns[:, 0, :, :].contiguous()
attn_align = None
if with_align:
if self.full_context_alignment:
# return _, (B, Q_len, K_len)
_, attns = self._forward(*args, **kwargs, future=True)

if self.alignment_heads > 0:
attns = attns[:, : self.alignment_heads, :, :].contiguous()
# layer average attention across heads, get ``(B, Q, K)``
# Case 1: no full_context, no align heads -> layer avg baseline
# Case 2: no full_context, 1 align heads -> guided align
# Case 3: full_context, 1 align heads -> full cte guided align
attn_align = attns.mean(dim=1)
return output, top_attn, attn_align

def update_dropout(self, dropout, attention_dropout):
@@ -317,9 +303,9 @@ def from_opts(cls, opts, embeddings, is_on_top=False):
embeddings,
opts.max_relative_positions,
opts.aan_useffn,
opts.full_context_alignment,
opts.alignment_layer,
alignment_heads=opts.alignment_heads,
False,
None,
alignment_heads=0.,
pos_ffn_activation_fn=opts.pos_ffn_activation_fn,
layer_norm_module=(
nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top
@@ -489,7 +475,6 @@ def forward(
src_max_len = memory_bank.size(1)
src_pad_mask = ~sequence_mask(memory_lengths, src_max_len).unsqueeze(1)

with_align = kwargs.pop("with_align", False)
attn_aligns = []

for i, layer in enumerate(self._get_layers()):
@@ -505,7 +490,6 @@
tgt_pad_mask,
layer_cache=layer_cache,
step=step,
with_align=with_align,
)
if attn_align is not None:
attn_aligns.append(attn_align)
@@ -517,9 +501,6 @@
attns = {"std": attn}
if self._copy:
attns["copy"] = attn
if with_align:
attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
# attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg

# TODO change the way attns is returned dict => list or tuple (onnx)
return dec_outs, attns
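With the alignment branches removed, the layer's forward wrapper reduces to a single _forward call; roughly (restating the remaining context lines above, not a verbatim copy of the file):

    def forward(self, *args, **kwargs):
        output, attns = self._forward(*args, **kwargs)
        top_attn = attns[:, 0, :, :].contiguous()  # attention weights of head 0
        attn_align = None                          # guided alignment is no longer supported
        return output, top_attn, attn_align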
43 changes: 11 additions & 32 deletions mammoth/opts.py
@@ -257,7 +257,11 @@ def model_opts(parser):
"For more detailed information, see: "
"https://arxiv.org/pdf/1803.02155.pdf",
)
group.add('--heads', '-heads', type=int, default=8, help='Number of heads for transformer self-attention')
group.add(
'--heads', '-heads', type=int, default=8,
help='Number of heads for transformer self-attention. '
' Semi-obsolete: not used for x-transformers, only used for some attention bridge configurations.'
)
group.add(
"-x_transformers_opts",
"--x_transformers_opts",
@@ -270,35 +274,6 @@
" https://github.com/lucidrains/x-transformers/blob/main/README.md ."
)

# Alignement options
# TODO is this actually in use?
group = parser.add_argument_group('Model - Alignment')
group.add(
'--lambda_align',
'-lambda_align',
type=float,
default=0.0,
help="Lambda value for alignement loss of Garg et al (2019)"
"For more detailed information, see: "
"https://arxiv.org/abs/1909.02074",
)
group.add(
'--alignment_layer', '-alignment_layer', type=int, default=-3, help='Layer number which has to be supervised.'
)
group.add(
'--alignment_heads',
'-alignment_heads',
type=int,
default=0,
help='N. of cross attention heads per layer to supervised with',
)
group.add(
'--full_context_alignment',
'-full_context_alignment',
action="store_true",
help='Whether alignment is conditioned on full target context.',
)

# Generator and loss options.
group = parser.add_argument_group('Generator')
group.add(
@@ -397,6 +372,10 @@ def _add_train_general_opts(parser):
default=None,
help='Criteria to use for early stopping.',
)
group.add(
'--max_nan_batches', '-max_nan_batches', type=int, default=5,
help='Number of batches that may be skipped due to loss blowout.'
)

# GPU
group = parser.add_argument_group('Computation Environment')
@@ -556,7 +535,7 @@ def _add_train_general_opts(parser):
'--max_grad_norm',
'-max_grad_norm',
type=float,
default=5,
default=1,
help="If the norm of the gradient vector exceeds this, "
"renormalize it to have the norm equal to "
"max_grad_norm",
Expand Down Expand Up @@ -679,7 +658,7 @@ def _add_train_general_opts(parser):
'-learning_rate',
type=float,
default=1.0,
help="Starting learning rate. Recommended settings: sgd = 1, adagrad = 0.1, adadelta = 1, adam = 0.001",
help="Starting learning rate. Recommended settings: sgd = TBD, adagrad = TBD, adadelta = TBD, adam = TBD",
)
group.add(
'--learning_rate_decay',
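The tighter --max_grad_norm default of 1 above is plain gradient-norm clipping. As a reminder of what the option controls, a generic PyTorch sketch (not the optimizer wrapper mammoth actually uses):

    import torch

    def clipped_step(model, optimizer, max_grad_norm=1.0):
        # Rescale the whole gradient vector if its L2 norm exceeds max_grad_norm,
        # then apply the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()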
14 changes: 0 additions & 14 deletions mammoth/tests/pull_request_chk.sh
@@ -88,20 +88,6 @@ ${PYTHON} onmt/bin/train.py \
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}

echo -n " [+] Testing NMT training w/ align..."
${PYTHON} onmt/bin/train.py \
-config ${DATA_DIR}/align_data.yaml \
-src_vocab $TMP_OUT_DIR/onmt.vocab.src \
-tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-encoder_type transformer -decoder_type transformer \
-layers 4 -word_vec_size 16 -rnn_size 16 -heads 2 -transformer_ff 64 \
-lambda_align 0.05 -alignment_layer 2 -alignment_heads 0 \
-report_every 5 -train_steps 10 >> ${LOG_FILE} 2>&1
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}

echo -n " [+] Testing LM training..."
${PYTHON} onmt/bin/train.py \
-config ${DATA_DIR}/lm_data.yaml \
7 changes: 6 additions & 1 deletion mammoth/train_single.py
@@ -44,6 +44,12 @@ def _get_model_opts(opts, frame_checkpoint=None):
def _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager):
"""Build iterator used for validation."""
if not any(opts.tasks[corpus_id].get('path_valid_src', False) for corpus_id in opts.tasks.keys()):
logger.info("Validation set missing for: {}".format(
[
corpus_id for corpus_id in opts.tasks.keys()
if not opts.tasks[corpus_id].get('path_valid_src', False)
]
))
return None
logger.info("creating validation iterator")
valid_iter = DynamicDatasetIter.from_opts(
@@ -180,7 +186,6 @@ def _train_iter():

train_iter = _train_iter()
# train_iter = iter_on_device(train_iter, device_context)
logger.info("Device {} - Valid iter".format(device_context.id))
valid_iter = _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager)

if len(opts.gpu_ranks):
36 changes: 26 additions & 10 deletions mammoth/trainer.py
@@ -102,6 +102,7 @@ def build_trainer(
task_queue_manager=task_queue_manager,
report_stats_from_parameters=opts.report_stats_from_parameters,
report_training_accuracy=opts.report_training_accuracy,
max_nan_batches=opts.max_nan_batches,
)
return trainer

@@ -149,6 +150,7 @@ def __init__(
task_queue_manager=None,
report_stats_from_parameters=False,
report_training_accuracy=False,
max_nan_batches=0,
):
# Basic attributes.
self.model = model
@@ -171,6 +173,8 @@ def __init__(
self.earlystopper = earlystopper
self.dropout = dropout
self.dropout_steps = dropout_steps
self.max_nan_batches = max_nan_batches
self.nan_batches = 0

self.task_queue_manager = task_queue_manager

@@ -382,27 +386,34 @@ def validate(self, valid_iter, moving_average=None, task=None):
for batch, metadata, _ in valid_iter:
if stats is None:
stats = mammoth.utils.Statistics()

src, src_lengths = batch.src if isinstance(batch.src, tuple) else (batch.src, None)
decoder_input = batch.tgt[:-1]
target = batch.tgt[1:]


stats.n_src_words += batch.src.mask.sum().item()
src = batch.src.tensor
src_mask = batch.src.mask
decoder_input = batch.tgt.tensor[:-1]
target = batch.tgt.tensor[1:]

with torch.cuda.amp.autocast(enabled=self.optim.amp):
# F-prop through the model.
logits, decoder_output = valid_model(
src,
decoder_input,
src_lengths,
rearrange(src, 't b 1 -> b t'),
rearrange(decoder_input, 't b 1 -> b t'),
rearrange(src_mask, 't b -> b t'),
metadata=metadata,
)
logits = rearrange(logits, 'b t i -> t b i')
decoder_output = rearrange(decoder_output, 'b t d -> t b d')

# Compute loss.
loss = self.loss_functions[metadata.tgt_lang](logits, target)
loss = self.loss_functions[metadata.tgt_lang](
rearrange(logits, 't b i -> (t b) i'),
rearrange(target, 't b 1 -> (t b)'),
)

# Update statistics.
padding_idx = self.loss_functions[metadata.tgt_lang].ignore_index
batch_stats = Statistics.from_loss_logits_target(
loss,
loss.item(),
logits,
target,
padding_idx,
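The reworked validation path converts between mammoth's time-major (t, b, ...) tensors and the batch-major layout the model expects via einops.rearrange, then flattens for the token-level loss. A standalone illustration of the patterns used in the hunk above:

    import torch
    from einops import rearrange

    src = torch.zeros(7, 2, 1, dtype=torch.long)         # (t, b, 1) time-major input
    print(rearrange(src, 't b 1 -> b t').shape)          # torch.Size([2, 7])

    logits = torch.zeros(7, 2, 100)                      # (t, b, vocab) after rearranging back
    print(rearrange(logits, 't b i -> (t b) i').shape)   # torch.Size([14, 100])

    target = torch.zeros(7, 2, 1, dtype=torch.long)      # (t, b, 1) gold tokens
    print(rearrange(target, 't b 1 -> (t b)').shape)     # torch.Size([14])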
Expand Down Expand Up @@ -476,6 +487,8 @@ def _gradient_accumulation(

try:
if loss is not None:
if torch.isnan(loss):
raise Exception('Loss blowout')
self.optim.backward(loss)

if self.report_training_accuracy:
@@ -500,6 +513,9 @@
except Exception:
traceback.print_exc()
logger.info("At step %d, we removed a batch - accum %d", self.optim.training_step, k)
self.nan_batches += 1
if self.nan_batches >= self.max_nan_batches:
raise Exception('Exceeded allowed --max_nan_batches.')
if len(seen_comm_batches) != 1:
logger.warning('Communication batches out of synch with batch accumulation')

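Together, --max_nan_batches and the new torch.isnan check tolerate occasional loss blowouts: a NaN batch is dropped instead of backpropagated, and training aborts once too many have been dropped. A condensed, generic sketch of that behaviour (not mammoth's actual trainer loop):

    import torch

    def train_epoch(model, loss_fn, optimizer, batches, max_nan_batches=5):
        nan_batches = 0
        for src, tgt in batches:
            loss = loss_fn(model(src), tgt)
            if torch.isnan(loss):
                nan_batches += 1            # skip the batch rather than backpropagate NaN
                if nan_batches >= max_nan_batches:
                    raise RuntimeError('Exceeded allowed --max_nan_batches.')
                continue
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()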
35 changes: 6 additions & 29 deletions mammoth/utils/parse.py
@@ -71,13 +71,6 @@ def _validate_tasks(cls, opts):
cls._validate_file(path_src, info=f'{cname}/path_src')
cls._validate_file(path_tgt, info=f'{cname}/path_tgt')
"""
path_align = corpus.get('path_align', None)
if path_align is None:
if hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0:
raise ValueError(f'Corpus {cname} alignment file path are required when lambda_align > 0.0')
corpus['path_align'] = None
else:
cls._validate_file(path_align, info=f'{cname}/path_align')
# Check prefix: will be used when use prefix transform
src_prefix = corpus.get('src_prefix', None)
tgt_prefix = corpus.get('tgt_prefix', None)
@@ -169,11 +162,6 @@ def _get_all_transform(cls, opts):
_transforms = set(corpus['transforms'])
if len(_transforms) != 0:
all_transforms.update(_transforms)
if hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0:
if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}):
raise ValueError('lambda_align is not compatible with on-the-fly tokenization.')
if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'denoising'}):
raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.')
opts._all_transform = all_transforms

@classmethod
@@ -265,11 +253,6 @@ def update_model_opts(cls, model_opts):
if hasattr(model_opts, 'fix_word_vecs_dec'):
model_opts.freeze_word_vecs_dec = model_opts.fix_word_vecs_dec

if model_opts.alignment_layer is None:
model_opts.alignment_layer = -2
model_opts.lambda_align = 0.0
model_opts.full_context_alignment = False

@classmethod
def validate_x_transformers_opts(cls, opts):
if not opts.x_transformers_opts:
@@ -279,7 +262,6 @@ def validate_x_transformers_opts(cls, opts):
for overwritten_key in (
'dim',
'depth',
'heads',
'causal',
'cross_attend',
'pre_norm_has_final_norm',
@@ -318,17 +300,6 @@ def validate_model_opts(cls, model_opts):
# if model_opts.share_embeddings:
# if model_opts.model_type != "text":
# raise AssertionError("--share_embeddings requires --model_type text.")
if model_opts.lambda_align > 0.0:
assert (
model_opts.alignment_layer < model_opts.dec_layers
and model_opts.alignment_layer >= -model_opts.dec_layers
), "N° alignment_layer should be smaller than number of layers."
logger.info(
"Joint learn alignment at layer [{}] "
"with {} heads in full_context '{}'.".format(
model_opts.alignment_layer, model_opts.alignment_heads, model_opts.full_context_alignment
)
)

cls.validate_x_transformers_opts(model_opts)

@@ -365,6 +336,12 @@ def validate_train_opts(cls, opts):
opts.accum_steps
), 'Number of accum_count values must match number of accum_steps'

if opts.decay_method not in {'linear_warmup', 'none'}:
logger.warn(
'Note that decay methods other than "linear_warmup" and "none" have weird scaling.'
' Did you tune the learning rate for this decay method?'
)

# TODO: do we want to remove that completely?
# if opts.update_vocab:
# assert opts.train_from, "-update_vocab needs -train_from option"
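The new warning notes that the learning rate defaults were retuned with linear_warmup (or no decay) in mind. For reference, a generic linear-warmup schedule looks roughly like this; illustrative only, mammoth's actual decay implementations may behave differently after the warmup phase:

    def linear_warmup(step, base_lr=1.0, warmup_steps=4000):
        # Ramp linearly from 0 to base_lr over warmup_steps, then hold.
        return base_lr * min(1.0, step / max(1, warmup_steps))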
4 changes: 2 additions & 2 deletions mammoth/utils/statistics.py
@@ -20,7 +20,7 @@ class Statistics(object):
* elapsed time
"""

def __init__(self, loss=0, n_words=0, n_correct=0):
def __init__(self, loss=0, n_words=0, n_correct=None):
self.loss = loss
self.n_words = n_words
self.n_correct = n_correct
@@ -136,7 +136,7 @@ def update_from_parameters(self, named_parameters):

def accuracy(self):
"""compute accuracy"""
if self.n_correct and self.n_words:
if self.n_correct is not None and self.n_words:
return 100 * (self.n_correct / self.n_words)
else:
return None
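Defaulting n_correct to None lets accuracy() tell "never computed" apart from a genuine zero, which the old truthiness check conflated. A small usage illustration based on the constructor and method shown above:

    from mammoth.utils import Statistics

    stats = Statistics(loss=3.2, n_words=100)               # n_correct stays None
    assert stats.accuracy() is None                         # accuracy unknown, not reported

    stats = Statistics(loss=3.2, n_words=100, n_correct=0)
    assert stats.accuracy() == 0.0                          # a real 0% is now reported as 0.0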
