From 9df5ea59ee83e99ca22b7b9cb5d0e708831654c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 25 Sep 2023 12:11:11 +0300
Subject: [PATCH] translation with prefix transforms

Add --src_prefix and --tgt_prefix options to translate_opts.

When the prefix transform in onmt/transforms/misc.py is warmed up and
get_prefix_dict(opts) raises AttributeError, build the prefix dictionary
from these two options under the 'translation' key instead. The else
branch of apply() now reads that 'translation' entry rather than taking
corpus_prefix from kwargs, and raises ValueError (matching the sibling
branch) when the entry is missing.
---
 onmt/opts.py            | 13 +++++++++++++
 onmt/transforms/misc.py | 21 +++++++++++++++++----
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/onmt/opts.py b/onmt/opts.py
index 8dae9f55..d8886a34 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -1260,6 +1260,19 @@ def translate_opts(parser, dynamic=False):
         # Adding options related to Transforms
         _add_dynamic_transform_opts(parser)
 
+        group.add(
+            "--src_prefix",
+            "-src_prefix",
+            default="",
+            help="The encoder prefix, i.e. language selector token",
+        )
+        group.add(
+            "--tgt_prefix",
+            "-tgt_prefix",
+            default="",
+            help="The decoder prefix (FIXME: does not work, but must be set nevertheless)",
+        )
+
 
 def build_bilingual_model(parser):
     """options for modular translation"""
diff --git a/onmt/transforms/misc.py b/onmt/transforms/misc.py
index 2c260a5c..f4ef92c8 100644
--- a/onmt/transforms/misc.py
+++ b/onmt/transforms/misc.py
@@ -92,7 +92,20 @@ def get_specials(cls, opts):
     def warm_up(self, vocabs=None):
         """Warm up to get prefix dictionary."""
         super().warm_up(None)
-        self.prefix_dict = self.get_prefix_dict(self.opts)
+        # TODO: The following try/except is a hack to work around the different
+        # structure of opts during training vs translation, and the fact that the transform
+        # does not know whether it is being warmed up for training or translation.
+        # This is most elegantly fixed by redesigning and unifying the formats of opts.
+        try:
+            # This should succeed during training
+            self.prefix_dict = self.get_prefix_dict(self.opts)
+        except AttributeError:
+            # Normal during translation
+            src_prefix = self.opts.src_prefix
+            tgt_prefix = self.opts.tgt_prefix
+            self.prefix_dict = {
+                'translation': {'src': src_prefix, 'tgt': tgt_prefix}
+            }
 
     def _prepend(self, example, prefix):
         """Prepend `prefix` to `tokens`."""
@@ -113,9 +126,9 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
             if corpus_prefix is None:
                 raise ValueError(f'prefix for {corpus_name} does not exist.')
         else:
-            corpus_prefix = kwargs.get('corpus_prefix', None)
-            if corpus_name is None:
-                raise ValueError('corpus_prefix is required.')
+            corpus_prefix = self.prefix_dict.get('translation', None)
+            if corpus_prefix is None:
+                raise ValueError('failed to set prefixes for translation')
         return self._prepend(example, corpus_prefix)
 
     def _repr_args(self):