From 8bbd18aca4d88c72e768ce192de663ecbf9b97fe Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Wed, 3 Jan 2024 15:00:04 +0100 Subject: [PATCH] fix #2329 (#2544) --- onmt/decoders/transformer.py | 2 ++ onmt/inputters/text_corpus.py | 2 +- onmt/train_single.py | 1 + onmt/utils/scoring_utils.py | 29 +++++++++++++++++------------ 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py index 557c566e20..47e2f338dc 100644 --- a/onmt/decoders/transformer.py +++ b/onmt/decoders/transformer.py @@ -737,6 +737,8 @@ def _init_cache(self, enc_out): ) if hasattr(layer.self_attn, "rope"): layer.self_attn.rope = layer.self_attn.rope.to(enc_out.device) + layer.self_attn.cos = layer.self_attn.cos.to(enc_out.device) + layer.self_attn.sin = layer.self_attn.sin.to(enc_out.device) class TransformerLMDecoderLayer(TransformerDecoderLayerBase): diff --git a/onmt/inputters/text_corpus.py b/onmt/inputters/text_corpus.py index 5641daae06..af0b63b7dc 100644 --- a/onmt/inputters/text_corpus.py +++ b/onmt/inputters/text_corpus.py @@ -130,7 +130,7 @@ def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None): corpora_dict[CorpusName.VALID] = ParallelCorpus( CorpusName.VALID, opts.data[CorpusName.VALID]["path_src"], - opts.data[CorpusName.VALID]["path_tgt"], + opts.data[CorpusName.VALID]["path_tgt"] if tgt is None else None, opts.data[CorpusName.VALID]["path_align"], n_src_feats=opts.n_src_feats, src_feats_defaults=opts.src_feats_defaults, diff --git a/onmt/train_single.py b/onmt/train_single.py index 967f4bd552..efd76a2752 100644 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -34,6 +34,7 @@ def prepare_transforms_vocabs(opt, transforms_cls): opt.transforms = validset_transforms if opt.data.get("valid", {}).get("tgt_prefix", None): opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None) + opt.tgt_file_prefix = True if opt.data.get("valid", {}).get("src_prefix", None): opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None) if opt.data.get("valid", {}).get("tgt_suffix", None): diff --git a/onmt/utils/scoring_utils.py b/onmt/utils/scoring_utils.py index a0553d9b2a..8e1e845718 100644 --- a/onmt/utils/scoring_utils.py +++ b/onmt/utils/scoring_utils.py @@ -2,7 +2,7 @@ import os from onmt.utils.parse import ArgumentParser from onmt.translate import GNMTGlobalScorer, Translator -from onmt.opts import translate_opts +from onmt.opts import config_opts, translate_opts from onmt.constants import CorpusTask from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe @@ -49,12 +49,16 @@ def translate(self, model, gpu_rank, step): # Translator # # ########## # - # Set translation options + # Set "default" translation options on empty cfgfile parser = ArgumentParser() + config_opts(parser) translate_opts(parser) base_args = ["-model", "dummy"] + ["-src", "dummy"] opt = parser.parse_args(base_args) opt.gpu = gpu_rank + if hasattr(self.opt, "tgt_file_prefix"): + opt.tgt_file_prefix = self.opt.tgt_file_prefix + opt.beam_size = 1 # prevent OOM when GPU is almost full at training ArgumentParser.validate_translate_opts(opt) # Build translator from options @@ -85,25 +89,26 @@ def translate(self, model, gpu_rank, step): model_opt.num_workers = 0 model_opt.tgt = None + # Retrieve raw references and sources + with codecs.open( + model_opt.data["valid"]["path_tgt"], "r", encoding="utf-8" + ) as f: + raw_refs = [line.strip("\n") for line in f if line.strip("\n")] + with codecs.open( + model_opt.data["valid"]["path_src"], "r", encoding="utf-8" + ) as f: + raw_srcs = [line.strip("\n") for line in f if line.strip("\n")] + valid_iter = build_dynamic_dataset_iter( model_opt, transforms_cls, translator.vocabs, task=CorpusTask.VALID, + tgt="", # This force to clear the target side (needed when using tgt_file_prefix) copy=model_opt.copy_attn, device_id=opt.gpu, ) - # Retrieve raw references and sources - with codecs.open( - valid_iter.corpora_info["valid"]["path_tgt"], "r", encoding="utf-8" - ) as f: - raw_refs = [line.strip("\n") for line in f if line.strip("\n")] - with codecs.open( - valid_iter.corpora_info["valid"]["path_src"], "r", encoding="utf-8" - ) as f: - raw_srcs = [line.strip("\n") for line in f if line.strip("\n")] - # ########### # # Predictions # # ########### #