Skip to content

Commit

Permalink
translation with prefix transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
Waino committed Sep 25, 2023
1 parent 6e09039 commit 9df5ea5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
13 changes: 13 additions & 0 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,19 @@ def translate_opts(parser, dynamic=False):
# Adding options related to Transforms
_add_dynamic_transform_opts(parser)

group.add(
"--src_prefix",
"-src_prefix",
default="",
help="The encoder prefix, i.e. language selector token",
)
group.add(
"--tgt_prefix",
"-tgt_prefix",
default="",
help="The decoder prefix (FIXME: does not work, but must be set nevertheless)",
)


def build_bilingual_model(parser):
"""options for modular translation"""
Expand Down
21 changes: 17 additions & 4 deletions onmt/transforms/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,20 @@ def get_specials(cls, opts):
def warm_up(self, vocabs=None):
    """Warm up to get prefix dictionary.

    Populates ``self.prefix_dict``. It is first built from ``self.opts``
    via ``get_prefix_dict``; if that raises ``AttributeError``, a
    single-entry dictionary keyed by ``'translation'`` is built from
    ``self.opts.src_prefix`` and ``self.opts.tgt_prefix`` instead.

    Args:
        vocabs: not used by this method; ``None`` is forwarded to the
            parent ``warm_up``.

    Raises:
        AttributeError: if ``get_prefix_dict`` fails and ``self.opts``
            also lacks ``src_prefix`` or ``tgt_prefix``.
    """
    super().warm_up(None)
    # TODO: The following try/except is a hack to work around the different
    # structure of opts during training vs translation, and the fact that the
    # transform does not know whether it is being warmed up for training or
    # translation. This is most elegantly fixed by redesigning and unifying
    # the formats of opts.
    #
    # The lookup must happen only inside the try block: an unguarded call
    # placed before it would raise first and the fallback would never run.
    try:
        # This should succeed during training
        self.prefix_dict = self.get_prefix_dict(self.opts)
    except AttributeError:
        # Normal during translation: use the explicit prefix options.
        self.prefix_dict = {
            'translation': {
                'src': self.opts.src_prefix,
                'tgt': self.opts.tgt_prefix,
            }
        }

def _prepend(self, example, prefix):
"""Prepend `prefix` to `tokens`."""
Expand All @@ -113,9 +126,9 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
if corpus_prefix is None:
raise ValueError(f'prefix for {corpus_name} does not exist.')
else:
corpus_prefix = kwargs.get('corpus_prefix', None)
if corpus_name is None:
raise ValueError('corpus_prefix is required.')
corpus_prefix = self.prefix_dict.get('translation', None)
if corpus_prefix is None:
raise Exception('failed to set prefixes for translation')

Check failure on line 131 in onmt/transforms/misc.py

View workflow job for this annotation

GitHub Actions / lint-and-tests (3.8)

failed to set prefixes for translation

Check failure on line 131 in onmt/transforms/misc.py

View workflow job for this annotation

GitHub Actions / lint-and-tests (3.8)

failed to set prefixes for translation
return self._prepend(example, corpus_prefix)

def _repr_args(self):
Expand Down

0 comments on commit 9df5ea5

Please sign in to comment.