diff --git a/mammoth/models/model.py b/mammoth/models/model.py index 5e950b36..912945fc 100644 --- a/mammoth/models/model.py +++ b/mammoth/models/model.py @@ -75,10 +75,10 @@ def forward(self, src, decoder_input, src_mask, metadata=None): return_embeddings=True, ) - # encoder_output, alphas = self.attention_bridge(encoder_output, src_mask) - # if self.attention_bridge.is_fixed_length: - # # turn off masking in the transformer decoder - # src_mask = None + encoder_output, alphas = self.attention_bridge(encoder_output, src_mask) + if self.attention_bridge.is_fixed_length: + # turn off masking in the transformer decoder + src_mask = None retval = active_decoder( decoder_input, diff --git a/tools/config_config.py b/tools/config_config.py index 82a14a76..6077db7f 100644 --- a/tools/config_config.py +++ b/tools/config_config.py @@ -822,7 +822,7 @@ def _add_language_pair(opts, src_lang, tgt_lang, src_path, tgt_path, valid_src_p if 'tasks' not in opts.in_config[0]: opts.in_config[0]['tasks'] = dict() tasks_section = opts.in_config[0]['tasks'] - key = f'train_{src_lang}-{tgt_lang}' + key = f'{src_lang}-{tgt_lang}' if key not in tasks_section: tasks_section[key] = dict() tasks_section[key]['src_tgt'] = f'{src_lang}-{tgt_lang}' diff --git a/tools/iterate_tasks.py b/tools/iterate_tasks.py index f3498661..ea5a137d 100644 --- a/tools/iterate_tasks.py +++ b/tools/iterate_tasks.py @@ -16,7 +16,7 @@ '--src', type=str, default=None, - help='Template for source file paths. Use varibles src_lang and tgt_lang.', + help='Template for source file paths. Use varibles src_lang, tgt_lang, and task_id.', ) @click.option( '--output', @@ -27,7 +27,7 @@ @click.option( '--flag', is_flag=True, - help='Prefix with "--task_id". Implied by --src and --output.' + help='Prefix output with "--task_id". Implied by --src and --output.' ) def main(config_path, match, src, output, flag): if src is not None or output is not None: