From 4f6620c4e5fbd3f721fce307f2c3d88baeb82603 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 7 Oct 2024 14:19:27 +0300 Subject: [PATCH] Remove more obsolete opts --- mammoth/model_builder.py | 1 - mammoth/modules/transformer_decoder.py | 29 ++++--------------- mammoth/opts.py | 39 +++++--------------------- mammoth/tests/pull_request_chk.sh | 14 --------- mammoth/train_single.py | 7 ++++- mammoth/utils/parse.py | 35 ++++------------------- mammoth/utils/statistics.py | 2 +- 7 files changed, 25 insertions(+), 102 deletions(-) diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py index b8b0afa6..8564cf1c 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -62,7 +62,6 @@ def get_attention_layers_kwargs( kwargs.update({ 'dim': model_opts.model_dim, 'depth': depth, - 'heads': model_opts.heads, 'causal': causal, 'cross_attend': cross_attend, 'pre_norm_has_final_norm': pre_norm_has_final_norm, diff --git a/mammoth/modules/transformer_decoder.py b/mammoth/modules/transformer_decoder.py index b96f3eef..cecb4a5f 100644 --- a/mammoth/modules/transformer_decoder.py +++ b/mammoth/modules/transformer_decoder.py @@ -60,6 +60,8 @@ def __init__( """ super(TransformerDecoderLayerBase, self).__init__() + assert not full_context_alignment, 'alignment is obsolete' + assert alignment_heads == 0, 'alignment is obsolete' if self_attn_type == "scaled-dot": self.self_attn = MultiHeadedAttention( @@ -86,8 +88,6 @@ def __init__( self.layer_norm_3 = nn.LayerNorm(d_model, eps=1e-6) self.layer_norm_4 = nn.LayerNorm(d_model, eps=1e-6) self.drop = nn.Dropout(dropout) - self.full_context_alignment = full_context_alignment - self.alignment_heads = alignment_heads def forward(self, *args, **kwargs): """Extend `_forward` for (possibly) multiple decoder pass: @@ -97,7 +97,6 @@ def forward(self, *args, **kwargs): Args: * All arguments of _forward. - with_align (bool): whether return alignment attention. 
        Returns:
             (FloatTensor, FloatTensor, FloatTensor or None):
 
@@ -106,22 +105,9 @@
             * top_attn ``(batch_size, T, src_len)``
             * attn_align ``(batch_size, T, src_len)`` or None
         """
-        with_align = kwargs.pop("with_align", False)
         output, attns = self._forward(*args, **kwargs)
         top_attn = attns[:, 0, :, :].contiguous()
         attn_align = None
-        if with_align:
-            if self.full_context_alignment:
-                # return _, (B, Q_len, K_len)
-                _, attns = self._forward(*args, **kwargs, future=True)
-
-            if self.alignment_heads > 0:
-                attns = attns[:, : self.alignment_heads, :, :].contiguous()
-            # layer average attention across heads, get ``(B, Q, K)``
-            # Case 1: no full_context, no align heads -> layer avg baseline
-            # Case 2: no full_context, 1 align heads -> guided align
-            # Case 3: full_context, 1 align heads -> full cte guided align
-            attn_align = attns.mean(dim=1)
         return output, top_attn, attn_align
 
     def update_dropout(self, dropout, attention_dropout):
@@ -317,9 +303,9 @@ def from_opts(cls, opts, embeddings, is_on_top=False):
             embeddings,
             opts.max_relative_positions,
             opts.aan_useffn,
-            opts.full_context_alignment,
-            opts.alignment_layer,
-            alignment_heads=opts.alignment_heads,
+            False,
+            None,
+            alignment_heads=0,
             pos_ffn_activation_fn=opts.pos_ffn_activation_fn,
             layer_norm_module=(
                 nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top
@@ -489,7 +475,6 @@ def forward(
         src_max_len = memory_bank.size(1)
         src_pad_mask = ~sequence_mask(memory_lengths, src_max_len).unsqueeze(1)
 
-        with_align = kwargs.pop("with_align", False)
         attn_aligns = []
 
         for i, layer in enumerate(self._get_layers()):
@@ -505,7 +490,6 @@
                 tgt_pad_mask,
                 layer_cache=layer_cache,
                 step=step,
-                with_align=with_align,
             )
             if attn_align is not None:
                 attn_aligns.append(attn_align)
@@ -517,9 +501,6 @@
         attns = {"std": attn}
         if self._copy:
             attns["copy"] = attn
-        if with_align:
-            attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`
-            # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
 
         # TODO change the way attns is returned dict => list or tuple (onnx)
         return dec_outs, attns
diff --git a/mammoth/opts.py b/mammoth/opts.py
index 81248c80..5d9f805c 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -255,7 +255,11 @@ def model_opts(parser):
         "For more detailed information, see: "
         "https://arxiv.org/pdf/1803.02155.pdf",
     )
-    group.add('--heads', '-heads', type=int, default=8, help='Number of heads for transformer self-attention')
+    group.add(
+        '--heads', '-heads', type=int, default=8,
+        help='Number of heads for transformer self-attention. '
+        'Semi-obsolete: not used for x-transformers, only used for some attention bridge configurations.'
+    )
     group.add(
         "-x_transformers_opts",
         "--x_transformers_opts",
@@ -268,35 +272,6 @@ def model_opts(parser):
         " https://github.com/lucidrains/x-transformers/blob/main/README.md ."
     )
 
-    # Alignement options
-    # TODO is this actually in use?
-    group = parser.add_argument_group('Model - Alignment')
-    group.add(
-        '--lambda_align',
-        '-lambda_align',
-        type=float,
-        default=0.0,
-        help="Lambda value for alignement loss of Garg et al (2019)"
-        "For more detailed information, see: "
-        "https://arxiv.org/abs/1909.02074",
-    )
-    group.add(
-        '--alignment_layer', '-alignment_layer', type=int, default=-3, help='Layer number which has to be supervised.'
-    )
-    group.add(
-        '--alignment_heads',
-        '-alignment_heads',
-        type=int,
-        default=0,
-        help='N. 
of cross attention heads per layer to supervised with', - ) - group.add( - '--full_context_alignment', - '-full_context_alignment', - action="store_true", - help='Whether alignment is conditioned on full target context.', - ) - # Generator and loss options. group = parser.add_argument_group('Generator') group.add( @@ -558,7 +533,7 @@ def _add_train_general_opts(parser): '--max_grad_norm', '-max_grad_norm', type=float, - default=5, + default=1, help="If the norm of the gradient vector exceeds this, " "renormalize it to have the norm equal to " "max_grad_norm", @@ -681,7 +656,7 @@ def _add_train_general_opts(parser): '-learning_rate', type=float, default=1.0, - help="Starting learning rate. Recommended settings: sgd = 1, adagrad = 0.1, adadelta = 1, adam = 0.001", + help="Starting learning rate. Recommended settings: sgd = TBD, adagrad = TBD, adadelta = TBD, adam = TBD", ) group.add( '--learning_rate_decay', diff --git a/mammoth/tests/pull_request_chk.sh b/mammoth/tests/pull_request_chk.sh index 04ea54a9..b07c6347 100755 --- a/mammoth/tests/pull_request_chk.sh +++ b/mammoth/tests/pull_request_chk.sh @@ -88,20 +88,6 @@ ${PYTHON} onmt/bin/train.py \ [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} -echo -n " [+] Testing NMT training w/ align..." -${PYTHON} onmt/bin/train.py \ - -config ${DATA_DIR}/align_data.yaml \ - -src_vocab $TMP_OUT_DIR/onmt.vocab.src \ - -tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \ - -src_vocab_size 1000 \ - -tgt_vocab_size 1000 \ - -encoder_type transformer -decoder_type transformer \ - -layers 4 -word_vec_size 16 -rnn_size 16 -heads 2 -transformer_ff 64 \ - -lambda_align 0.05 -alignment_layer 2 -alignment_heads 0 \ - -report_every 5 -train_steps 10 >> ${LOG_FILE} 2>&1 -[ "$?" -eq 0 ] || error_exit -echo "Succeeded" | tee -a ${LOG_FILE} - echo -n " [+] Testing LM training..." 
${PYTHON} onmt/bin/train.py \ -config ${DATA_DIR}/lm_data.yaml \ diff --git a/mammoth/train_single.py b/mammoth/train_single.py index 412f2011..036cde05 100644 --- a/mammoth/train_single.py +++ b/mammoth/train_single.py @@ -44,6 +44,12 @@ def _get_model_opts(opts, frame_checkpoint=None): def _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager): """Build iterator used for validation.""" if not any(opts.tasks[corpus_id].get('path_valid_src', False) for corpus_id in opts.tasks.keys()): + logger.info("Validation set missing for: {}".format( + [ + corpus_id for corpus_id in opts.tasks.keys() + if not opts.tasks[corpus_id].get('path_valid_src', False) + ] + )) return None logger.info("creating validation iterator") valid_iter = DynamicDatasetIter.from_opts( @@ -176,7 +182,6 @@ def _train_iter(): train_iter = _train_iter() # train_iter = iter_on_device(train_iter, device_context) - logger.info("Device {} - Valid iter".format(device_context.id)) valid_iter = _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager) if len(opts.gpu_ranks): diff --git a/mammoth/utils/parse.py b/mammoth/utils/parse.py index c0b81e9d..eacb44f2 100644 --- a/mammoth/utils/parse.py +++ b/mammoth/utils/parse.py @@ -71,13 +71,6 @@ def _validate_tasks(cls, opts): cls._validate_file(path_src, info=f'{cname}/path_src') cls._validate_file(path_tgt, info=f'{cname}/path_tgt') """ - path_align = corpus.get('path_align', None) - if path_align is None: - if hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0: - raise ValueError(f'Corpus {cname} alignment file path are required when lambda_align > 0.0') - corpus['path_align'] = None - else: - cls._validate_file(path_align, info=f'{cname}/path_align') # Check prefix: will be used when use prefix transform src_prefix = corpus.get('src_prefix', None) tgt_prefix = corpus.get('tgt_prefix', None) @@ -169,11 +162,6 @@ def _get_all_transform(cls, opts): _transforms = set(corpus['transforms']) if len(_transforms) != 0: all_transforms.update(_transforms) - if hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0: - if not all_transforms.isdisjoint({'sentencepiece', 'bpe', 'onmt_tokenize'}): - raise ValueError('lambda_align is not compatible with on-the-fly tokenization.') - if not all_transforms.isdisjoint({'tokendrop', 'prefix', 'denoising'}): - raise ValueError('lambda_align is not compatible yet with potential token deletion/addition.') opts._all_transform = all_transforms @classmethod @@ -265,11 +253,6 @@ def update_model_opts(cls, model_opts): if hasattr(model_opts, 'fix_word_vecs_dec'): model_opts.freeze_word_vecs_dec = model_opts.fix_word_vecs_dec - if model_opts.alignment_layer is None: - model_opts.alignment_layer = -2 - model_opts.lambda_align = 0.0 - model_opts.full_context_alignment = False - @classmethod def validate_x_transformers_opts(cls, opts): if not opts.x_transformers_opts: @@ -279,7 +262,6 @@ def validate_x_transformers_opts(cls, opts): for overwritten_key in ( 'dim', 'depth', - 'heads', 'causal', 'cross_attend', 'pre_norm_has_final_norm', @@ -318,17 +300,6 @@ def validate_model_opts(cls, model_opts): # if model_opts.share_embeddings: # if model_opts.model_type != "text": # raise AssertionError("--share_embeddings requires --model_type text.") - if model_opts.lambda_align > 0.0: - assert ( - model_opts.alignment_layer < model_opts.dec_layers - and model_opts.alignment_layer >= -model_opts.dec_layers - ), "N° alignment_layer should be smaller than number of layers." 
-            logger.info(
-                "Joint learn alignment at layer [{}] "
-                "with {} heads in full_context '{}'.".format(
-                    model_opts.alignment_layer, model_opts.alignment_heads, model_opts.full_context_alignment
-                )
-            )
 
         cls.validate_x_transformers_opts(model_opts)
 
@@ -365,6 +336,12 @@ def validate_train_opts(cls, opts):
             opts.accum_steps
         ), 'Number of accum_count values must match number of accum_steps'
 
+        if opts.decay_method not in {'linear_warmup', 'none'}:
+            logger.warning(
+                'Note that decay methods other than "linear_warmup" and "none" apply additional scaling'
+                ' to the learning rate. Did you tune the learning rate for this decay method?'
+            )
+
         # TODO: do we want to remove that completely?
         # if opts.update_vocab:
         #     assert opts.train_from, "-update_vocab needs -train_from option"
diff --git a/mammoth/utils/statistics.py b/mammoth/utils/statistics.py
index ca3a0021..cd4ad3e4 100644
--- a/mammoth/utils/statistics.py
+++ b/mammoth/utils/statistics.py
@@ -136,7 +136,7 @@ def update_from_parameters(self, named_parameters):
 
     def accuracy(self):
         """compute accuracy"""
-        if self.n_correct and self.n_words:
+        if self.n_correct is not None and self.n_words:
             return 100 * (self.n_correct / self.n_words)
         else:
             return None