translation with prefix transforms
- Retrieve prefixes in a different way during training and translation (see the sketch below)
- Pass the task to the translation dataset
- Allow the target to be None

Fix failing unit tests

- The prefix transform requires is_train to be set correctly.
- Always raise ValueError on errors, even though it is not always the most appropriate exception type.
Waino committed Oct 2, 2023
1 parent 9874bb7 commit 9310c4d
Showing 4 changed files with 47 additions and 18 deletions.
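
Before the per-file diffs, a minimal, self-contained sketch of the behaviour this commit targets. It uses simplified stand-ins, not mammoth's actual transform API: at training time the prefix is looked up per corpus name, at translation time a single 'trans' entry is used, and a missing target side is left untouched.

# Illustrative sketch only; apply_prefix is a simplified stand-in for the
# real prefix transform, not mammoth's API.
from typing import Dict, List, Optional

Example = Dict[str, Optional[List[str]]]
PrefixDict = Dict[str, Dict[str, str]]


def apply_prefix(example: Example, prefix_dict: PrefixDict,
                 is_train: bool, corpus_name: Optional[str] = None) -> Example:
    """Prepend the configured prefix tokens to an example."""
    if is_train:
        if corpus_name is None:
            raise ValueError('corpus_name is required.')
        prefix = prefix_dict[corpus_name]  # per-corpus prefix during training
    else:
        prefix = prefix_dict['trans']      # single flat entry during translation
    for side, side_prefix in prefix.items():
        # A None side (e.g. no target at translation time) is left untouched.
        if example.get(side) is not None and side_prefix:
            example[side] = side_prefix.split() + example[side]
    return example


# Translation time: only a 'trans' entry, and no target sentence yet.
prefix_dict = {'trans': {'src': '⦅_pf_src⦆', 'tgt': ''}}
ex = {'src': ['Hello', 'world', '.'], 'tgt': None}
print(apply_prefix(ex, prefix_dict, is_train=False))
# {'src': ['⦅_pf_src⦆', 'Hello', 'world', '.'], 'tgt': None}
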
16 changes: 13 additions & 3 deletions mammoth/opts.py
@@ -1193,9 +1193,6 @@ def translate_opts(parser, dynamic=False):
"Ex: {'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}",
) # noqa: E501
group.add('--tgt', '-tgt', help='True target sequence (optional)')
- group.add(
- '--tgt_prefix', '-tgt_prefix', action='store_true', help='Generate predictions using provided `-tgt` as prefix.'
- )
group.add(
'--shard_size',
'-shard_size',
@@ -1248,6 +1245,19 @@
# Adding options related to Transforms
_add_dynamic_transform_opts(parser)

+ group.add(
+ "--src_prefix",
+ "-src_prefix",
+ default="",
+ help="The encoder prefix, i.e. language selector token",
+ )
+ group.add(
+ "--tgt_prefix",
+ "-tgt_prefix",
+ default="",
+ help="The decoder prefix (FIXME: does not work, but must be set nevertheless)",
+ )


def build_bilingual_model(parser):
"""options for modular translation"""
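
The two options above replace the old boolean --tgt_prefix flag with string-valued prefixes. A minimal argparse stand-in (not mammoth's own option parser, which uses its `group.add` wrapper) of how the new flags behave, assuming a selector token such as '<to_de>':

# Stand-in for illustration only; it mirrors the two additions above rather
# than mammoth's actual parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--src_prefix', '-src_prefix', default='',
                    help='The encoder prefix, i.e. language selector token')
parser.add_argument('--tgt_prefix', '-tgt_prefix', default='',
                    help='The decoder prefix (currently must be set even though it is unused)')

# e.g. a language-selector token on the encoder side, empty decoder prefix:
opts = parser.parse_args(['--src_prefix', '<to_de>', '--tgt_prefix', ''])
print(opts.src_prefix, repr(opts.tgt_prefix))  # <to_de> ''
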
11 changes: 6 additions & 5 deletions mammoth/tests/test_transform.py
@@ -85,7 +85,7 @@ def test_transform_pipe(self):
"tgt": ["Bonjour", "le", "monde", "."],
}
# 4. apply transform pipe for example
- ex_after = transform_pipe.apply(copy.deepcopy(ex), corpus_name="trainset")
+ ex_after = transform_pipe.apply(copy.deepcopy(ex), is_train=True, corpus_name="trainset")
# 5. the example after the pipe exceeds the length limit, thus it is filtered
self.assertIsNone(ex_after)
# 6. Transform statistics registered (here for filtertoolong)
@@ -121,8 +121,9 @@ def test_prefix(self):
}
with self.assertRaises(ValueError):
prefix_transform.apply(ex_in)
- prefix_transform.apply(ex_in, corpus_name="validset")
- ex_out = prefix_transform.apply(ex_in, corpus_name="trainset")
+ with self.assertRaises(ValueError):
+ prefix_transform.apply(ex_in, is_train=False, corpus_name="validset")
+ ex_out = prefix_transform.apply(ex_in, is_train=True, corpus_name="trainset")
self.assertEqual(ex_out["src"][0], "⦅_pf_src⦆")
self.assertEqual(ex_out["tgt"][0], "⦅_pf_tgt⦆")

@@ -135,10 +136,10 @@ def test_filter_too_long(self):
"src": ["Hello", "world", "."],
"tgt": ["Bonjour", "le", "monde", "."],
}
- ex_out = filter_transform.apply(ex_in)
+ ex_out = filter_transform.apply(ex_in, is_train=True)
self.assertIs(ex_out, ex_in)
filter_transform.tgt_seq_length = 2
- ex_out = filter_transform.apply(ex_in)
+ ex_out = filter_transform.apply(ex_in, is_train=True)
self.assertIsNone(ex_out)


35 changes: 27 additions & 8 deletions mammoth/transforms/misc.py
@@ -43,25 +43,44 @@ def get_specials(cls, opts):
def warm_up(self, vocabs=None):
"""Warm up to get prefix dictionary."""
super().warm_up(None)
- self.prefix_dict = self.get_prefix_dict(self.opts)
+ # TODO: The following try/except is a hack to work around the different
+ # structure of opts during training vs translation, and the fact that the transform
+ # does not know whether it is being warmed up for training or translation.
+ # This is most elegantly fixed by redesigning and unifying the formats of opts.
+ try:
+ # This should succeed during training
+ self.prefix_dict = self.get_prefix_dict(self.opts)
+ except AttributeError:
+ # Normal during translation
+ src_prefix = self.opts.src_prefix
+ tgt_prefix = self.opts.tgt_prefix
+ self.prefix_dict = {
+ 'trans': {'src': src_prefix, 'tgt': tgt_prefix}
+ }

def _prepend(self, example, prefix):
"""Prepend `prefix` to `tokens`."""
for side, side_prefix in prefix.items():
- example[side] = side_prefix.split() + example[side]
+ if example[side] is not None:
+ example[side] = side_prefix.split() + example[side]
return example

def apply(self, example, is_train=False, stats=None, **kwargs):
"""Apply prefix prepend to example.
Should provide `corpus_name` to get the corresponding prefix.
"""
- corpus_name = kwargs.get('corpus_name', None)
- if corpus_name is None:
- raise ValueError('corpus_name is required.')
- corpus_prefix = self.prefix_dict.get(corpus_name, None)
- if corpus_prefix is None:
- raise ValueError(f'prefix for {corpus_name} does not exist.')
+ if is_train:
+ corpus_name = kwargs.get('corpus_name', None)
+ if corpus_name is None:
+ raise ValueError('corpus_name is required.')
+ corpus_prefix = self.prefix_dict.get(corpus_name, None)
+ if corpus_prefix is None:
+ raise ValueError(f'prefix for {corpus_name} does not exist.')
+ else:
+ corpus_prefix = self.prefix_dict.get('trans', None)
+ if corpus_prefix is None:
+ raise ValueError('failed to set prefixes for translation')
return self._prepend(example, corpus_prefix)

def _repr_args(self):
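
A compact sketch of the warm-up fallback above: try the training-time per-corpus lookup first and fall back to the flat translation options on AttributeError. The option objects and the per-corpus config shape (opts.corpora) are simplified assumptions for illustration, not mammoth's real opts.

# Simplified stand-ins for the two opts structures; the per-corpus config
# shape (opts.corpora) is an assumption for illustration.
from types import SimpleNamespace


def warm_up_prefixes(opts):
    try:
        # Training-style opts: per-corpus prefix configuration.
        return {
            name: {'src': cfg.get('src_prefix', ''), 'tgt': cfg.get('tgt_prefix', '')}
            for name, cfg in opts.corpora.items()
        }
    except AttributeError:
        # Translation-style opts: only the flat --src_prefix/--tgt_prefix values.
        return {'trans': {'src': opts.src_prefix, 'tgt': opts.tgt_prefix}}


train_opts = SimpleNamespace(corpora={'trainset': {'src_prefix': '⦅_pf_src⦆', 'tgt_prefix': '⦅_pf_tgt⦆'}})
trans_opts = SimpleNamespace(src_prefix='⦅_pf_src⦆', tgt_prefix='')
print(warm_up_prefixes(train_opts))  # {'trainset': {'src': '⦅_pf_src⦆', 'tgt': '⦅_pf_tgt⦆'}}
print(warm_up_prefixes(trans_opts))  # {'trans': {'src': '⦅_pf_src⦆', 'tgt': ''}}
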
3 changes: 1 addition & 2 deletions mammoth/translate/translator.py
Expand Up @@ -350,8 +350,6 @@ def translate_dynamic(
if batch_size is None:
raise ValueError("batch_size must be set")

- if self.tgt_prefix and tgt is None:
- raise ValueError("Prefix should be feed to tgt if -tgt_prefix.")
#
# data_iter = InferenceDataIterator(src, tgt, src_feats, transform)
#
@@ -474,6 +472,7 @@ def _translate(
transforms=transforms, # I suppose you might want *some* transforms
# batch_size=batch_size,
# batch_type=batch_type,
+ task=self.task,
).to(self._dev)

batches = build_dataloader(
