Skip to content

Commit

Permalink
WIP: translation with prefix transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
Waino committed Sep 25, 2023
1 parent 86b8183 commit 6e09039
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
3 changes: 0 additions & 3 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,9 +1208,6 @@ def translate_opts(parser, dynamic=False):
"Ex: {'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}",
) # noqa: E501
group.add('--tgt', '-tgt', help='True target sequence (optional)')
group.add(
'--tgt_prefix', '-tgt_prefix', action='store_true', help='Generate predictions using provided `-tgt` as prefix.'
)
group.add(
'--shard_size',
'-shard_size',
Expand Down
17 changes: 11 additions & 6 deletions onmt/transforms/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,17 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
Should provide `corpus_name` to get correspond prefix.
"""
corpus_name = kwargs.get('corpus_name', None)
if corpus_name is None:
raise ValueError('corpus_name is required.')
corpus_prefix = self.prefix_dict.get(corpus_name, None)
if corpus_prefix is None:
raise ValueError(f'prefix for {corpus_name} does not exist.')
if is_train:
corpus_name = kwargs.get('corpus_name', None)
if corpus_name is None:
raise ValueError('corpus_name is required.')
corpus_prefix = self.prefix_dict.get(corpus_name, None)
if corpus_prefix is None:
raise ValueError(f'prefix for {corpus_name} does not exist.')
else:
corpus_prefix = kwargs.get('corpus_prefix', None)
if corpus_name is None:
raise ValueError('corpus_prefix is required.')
return self._prepend(example, corpus_prefix)

def _repr_args(self):
Expand Down
2 changes: 0 additions & 2 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,6 @@ def translate_dynamic(
if batch_size is None:
raise ValueError("batch_size must be set")

if self.tgt_prefix and tgt is None:
raise ValueError("Prefix should be feed to tgt if -tgt_prefix.")
#
# data_iter = InferenceDataIterator(src, tgt, src_feats, transform)
#
Expand Down

0 comments on commit 6e09039

Please sign in to comment.