diff --git a/onmt/opts.py b/onmt/opts.py index 21abd96a3d..4a594237aa 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -1793,6 +1793,13 @@ def _add_decoding_opts(parser): "(or the identified source token does not exist in " "the table), then it will copy the source token.", ) + group.add( + "--stop_token", + "-stop_token", + type=str, + default="", + help="Stop token to be used instead of the EOS token.", + ) def translate_opts(parser): diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py index 5ca6ca0544..925a5a5fb1 100644 --- a/onmt/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -160,6 +160,7 @@ def __init__( self.topk_scores = None self.beam_size = beam_size self.n_best = n_best + self.parallel_paths = 1 def initialize( self, enc_out, src_len, src_map=None, device=None, target_prefix=None diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 98caf3d0fd..a9aef6699c 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -46,6 +46,8 @@ def build_translator(opt, device_id=0, report_score=True, logger=None, out_file= report_score=report_score, logger=logger, ) + if opt.stop_token == DefaultTokens.SEP: + translator.eos = translator.vocabs["tgt"].lookup_token("<0x0A>") else: translator = Translator.from_opt( model, @@ -142,7 +144,7 @@ def __init__( self._tgt_pad_idx = vocabs["tgt"].lookup_token(DefaultTokens.PAD) self._tgt_bos_idx = vocabs["tgt"].lookup_token(DefaultTokens.BOS) self._tgt_unk_idx = vocabs["tgt"].lookup_token(DefaultTokens.UNK) - self._tgt_sep_idx = vocabs["tgt"].lookup_token(DefaultTokens.SEP) + self._tgt_sep_idx = vocabs["tgt"].lookup_token("<0x0A>") self._tgt_start_with = vocabs["tgt"].lookup_token(vocabs["decoder_start_token"]) self._tgt_vocab_len = len(self._tgt_vocab) @@ -907,7 +909,6 @@ def _translate_batch_with_strategy(self, batch, decode_strategy): # (0) Prep the components of the search. 
use_src_map = self.copy_attn parallel_paths = decode_strategy.parallel_paths # beam_size - batch_size = len(batch["srclen"]) # (1) Run the encoder on the src. @@ -1022,6 +1023,7 @@ def translate_batch(self, batch, attn_debug, scoring=False): max_length = 0 if scoring else self.max_length with torch.no_grad(): if self.sample_from_topk != 0 or self.sample_from_topp != 0: + self.beam_size = 1 decode_strategy = GreedySearchLM( pad=self._tgt_pad_idx, bos=self._tgt_bos_idx,