From 6e090390a9e8d690a97d6bf2433776fa1aa89ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 25 Sep 2023 11:00:09 +0300 Subject: [PATCH] WIP: translation with prefix transforms --- onmt/opts.py | 3 --- onmt/transforms/misc.py | 17 +++++++++++------ onmt/translate/translator.py | 2 -- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/onmt/opts.py b/onmt/opts.py index 531a88c7..8dae9f55 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -1208,9 +1208,6 @@ def translate_opts(parser, dynamic=False): "Ex: {'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}", ) # noqa: E501 group.add('--tgt', '-tgt', help='True target sequence (optional)') - group.add( - '--tgt_prefix', '-tgt_prefix', action='store_true', help='Generate predictions using provided `-tgt` as prefix.' - ) group.add( '--shard_size', '-shard_size', diff --git a/onmt/transforms/misc.py b/onmt/transforms/misc.py index a7b8e1a0..2c260a5c 100644 --- a/onmt/transforms/misc.py +++ b/onmt/transforms/misc.py @@ -105,12 +105,17 @@ def apply(self, example, is_train=False, stats=None, **kwargs): Should provide `corpus_name` to get correspond prefix. """ - corpus_name = kwargs.get('corpus_name', None) - if corpus_name is None: - raise ValueError('corpus_name is required.') - corpus_prefix = self.prefix_dict.get(corpus_name, None) - if corpus_prefix is None: - raise ValueError(f'prefix for {corpus_name} does not exist.') + if is_train: + corpus_name = kwargs.get('corpus_name', None) + if corpus_name is None: + raise ValueError('corpus_name is required.') + corpus_prefix = self.prefix_dict.get(corpus_name, None) + if corpus_prefix is None: + raise ValueError(f'prefix for {corpus_name} does not exist.') + else: + corpus_prefix = kwargs.get('corpus_prefix', None) + if corpus_name is None: + raise ValueError('corpus_prefix is required.') return self._prepend(example, corpus_prefix) def _repr_args(self): diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 44c9092c..421032ce 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -350,8 +350,6 @@ def translate_dynamic( if batch_size is None: raise ValueError("batch_size must be set") - if self.tgt_prefix and tgt is None: - raise ValueError("Prefix should be feed to tgt if -tgt_prefix.") # # data_iter = InferenceDataIterator(src, tgt, src_feats, transform) #