From 5f036ab87a792c069d7d22cf0812343c4c29ce5f Mon Sep 17 00:00:00 2001
From: l-k-11235
Date: Thu, 4 Apr 2024 16:57:35 +0200
Subject: [PATCH 1/3] add warm-up method for inference engines

---
 onmt/inference_engine.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/onmt/inference_engine.py b/onmt/inference_engine.py
index b088f497ef..1998abaa51 100755
--- a/onmt/inference_engine.py
+++ b/onmt/inference_engine.py
@@ -155,6 +155,13 @@ def __init__(self, opt):
         self.transforms_cls = get_transforms_cls(opt._all_transform)
         self.vocabs = self.translator.vocabs
 
+    def warm_up(self):
+        from onmt.translate.translator import build_translator
+
+        self.translator = build_translator(
+            self.opt, self.device_id, logger=self.logger, report_score=True
+        )
+
     def _translate(self, infer_iter):
         scores, preds = self.translator._translate(
             infer_iter, infer_iter.transforms, self.opt.attn_debug, self.opt.align_debug

From 511288131b731e155fc5f189fca4835c7f7d6763 Mon Sep 17 00:00:00 2001
From: l-k-11235
Date: Mon, 8 Apr 2024 15:44:45 +0200
Subject: [PATCH 2/3] removed warm-up method for inference engines

---
 onmt/inference_engine.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/onmt/inference_engine.py b/onmt/inference_engine.py
index 1998abaa51..b088f497ef 100755
--- a/onmt/inference_engine.py
+++ b/onmt/inference_engine.py
@@ -155,13 +155,6 @@ def __init__(self, opt):
         self.transforms_cls = get_transforms_cls(opt._all_transform)
         self.vocabs = self.translator.vocabs
 
-    def warm_up(self):
-        from onmt.translate.translator import build_translator
-
-        self.translator = build_translator(
-            self.opt, self.device_id, logger=self.logger, report_score=True
-        )
-
     def _translate(self, infer_iter):
         scores, preds = self.translator._translate(
             infer_iter, infer_iter.transforms, self.opt.attn_debug, self.opt.align_debug

From 578f7197a336d930e82261661c1671721151afe2 Mon Sep 17 00:00:00 2001
From: l-k-11235
Date: Mon, 8 Apr 2024 16:17:13 +0200
Subject: [PATCH 3/3] fixed 'beam size' value in greedy search and allowed to
 define a stop token

---
 onmt/opts.py                    | 7 +++++++
 onmt/translate/greedy_search.py | 1 +
 onmt/translate/translator.py    | 7 +++++--
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/onmt/opts.py b/onmt/opts.py
index 21abd96a3d..4a594237aa 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -1793,6 +1793,13 @@ def _add_decoding_opts(parser):
         "(or the identified source token does not exist in "
         "the table), then it will copy the source token.",
     )
+    group.add(
+        "--stop_token",
+        "-stop_token",
+        type=str,
+        default="",
+        help="Stop token to be used instead of the EOS token.",
+    )
 
 
 def translate_opts(parser):

diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py
index 5ca6ca0544..925a5a5fb1 100644
--- a/onmt/translate/greedy_search.py
+++ b/onmt/translate/greedy_search.py
@@ -160,6 +160,7 @@ def __init__(
         self.topk_scores = None
         self.beam_size = beam_size
         self.n_best = n_best
+        self.parallel_paths = 1
 
     def initialize(
         self, enc_out, src_len, src_map=None, device=None, target_prefix=None

diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 98caf3d0fd..a9aef6699c 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -46,6 +46,8 @@ def build_translator(opt, device_id=0, report_score=True, logger=None, out_file=
             report_score=report_score,
             logger=logger,
         )
+        if opt.stop_token == DefaultTokens.SEP:
+            translator.eos = translator.vocabs["tgt"].lookup_token("<0x0A>")
     else:
         translator = Translator.from_opt(
             model,
@@ -142,7 +144,7 @@ def __init__(
         self._tgt_pad_idx = vocabs["tgt"].lookup_token(DefaultTokens.PAD)
         self._tgt_bos_idx = vocabs["tgt"].lookup_token(DefaultTokens.BOS)
         self._tgt_unk_idx = vocabs["tgt"].lookup_token(DefaultTokens.UNK)
-        self._tgt_sep_idx = vocabs["tgt"].lookup_token(DefaultTokens.SEP)
+        self._tgt_sep_idx = vocabs["tgt"].lookup_token("<0x0A>")
         self._tgt_start_with = vocabs["tgt"].lookup_token(vocabs["decoder_start_token"])
         self._tgt_vocab_len = len(self._tgt_vocab)
 
@@ -907,7 +909,6 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
         # (0) Prep the components of the search.
         use_src_map = self.copy_attn
         parallel_paths = decode_strategy.parallel_paths  # beam_size
-
         batch_size = len(batch["srclen"])
 
         # (1) Run the encoder on the src.
@@ -1022,6 +1023,7 @@ def translate_batch(self, batch, attn_debug, scoring=False):
         max_length = 0 if scoring else self.max_length
         with torch.no_grad():
             if self.sample_from_topk != 0 or self.sample_from_topp != 0:
+                self.beam_size = 1
                 decode_strategy = GreedySearchLM(
                     pad=self._tgt_pad_idx,
                     bos=self._tgt_bos_idx,
@@ -1128,6 +1130,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
         # (4) Begin decoding step by step:
         # beg_time = time()
         for step in range(decode_strategy.max_length):
+            print("# step", step)
             decoder_input = (
                 src if step == 0 else decode_strategy.current_predictions.view(-1, 1, 1)
             )
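
Note on the stop-token change in PATCH 3/3: when -stop_token matches
DefaultTokens.SEP, build_translator remaps the translator's EOS id to the
byte-level newline piece "<0x0A>", so generation stops at the first newline.
A minimal sketch of that lookup, assuming an OpenNMT-py checkout and a target
vocab that actually contains the "<0x0A>" piece (both assumptions; the helper
name below is hypothetical):

    # Sketch only: mirrors the override added to onmt/translate/translator.py.
    # Assumes vocabs["tgt"] exposes lookup_token(), as used in the diff above.
    from onmt.constants import DefaultTokens

    def resolve_stop_token_id(opt, vocabs, default_eos_id):
        """Return the token id that decoding should stop on."""
        if opt.stop_token == DefaultTokens.SEP:
            # Stop on the raw newline byte piece instead of the EOS token.
            return vocabs["tgt"].lookup_token("<0x0A>")
        return default_eos_id

On the command line this is exercised by passing -stop_token with the exact
spelling of DefaultTokens.SEP from onmt/constants.py; any other value leaves
the regular EOS behaviour untouched.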
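
The 'beam size' half of PATCH 3/3: _translate_batch_with_strategy reads
decode_strategy.parallel_paths as the beam size when tiling the batch, so
GreedySearch now advertises parallel_paths = 1 and translate_batch pins
beam_size = 1 before building GreedySearchLM. A toy illustration of why the
tiling factor must be 1 for greedy decoding (the shapes and the
repeat_interleave call are illustrative assumptions, not OpenNMT-py's actual
tile helper):

    # Sketch only: toy tensor shapes, not the library's real tiling code.
    import torch

    batch_size = 2
    parallel_paths = 1  # 1 for GreedySearch after this patch; beam_size for BeamSearch

    enc_out = torch.randn(batch_size, 5, 8)  # (batch, src_len, hidden)
    # Each source sentence is duplicated once per decoding path:
    tiled = enc_out.repeat_interleave(parallel_paths, dim=0)
    assert tiled.size(0) == batch_size * parallel_paths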