Skip to content

Commit

Permalink
Fix failing unit tests
Browse files Browse the repository at this point in the history
- Prefix transform requires is_train to be correctly set.
- Always raising ValueError on errors, although it is not completely appropriate.
  • Loading branch information
Waino committed Oct 2, 2023
1 parent 882262f commit 9d6fb12
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions onmt/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_transform_pipe(self):
"tgt": ["Bonjour", "le", "monde", "."],
}
# 4. apply transform pipe for example
ex_after = transform_pipe.apply(copy.deepcopy(ex), corpus_name="trainset")
ex_after = transform_pipe.apply(copy.deepcopy(ex), is_train=True, corpus_name="trainset")
# 5. example after the pipe exceed the length limit, thus filtered
self.assertIsNone(ex_after)
# 6. Transform statistics registed (here for filtertoolong)
Expand Down Expand Up @@ -121,8 +121,9 @@ def test_prefix(self):
}
with self.assertRaises(ValueError):
prefix_transform.apply(ex_in)
prefix_transform.apply(ex_in, corpus_name="validset")
ex_out = prefix_transform.apply(ex_in, corpus_name="trainset")
with self.assertRaises(ValueError):
prefix_transform.apply(ex_in, is_train=False, corpus_name="validset")
ex_out = prefix_transform.apply(ex_in, is_train=True, corpus_name="trainset")
self.assertEqual(ex_out["src"][0], "⦅_pf_src⦆")
self.assertEqual(ex_out["tgt"][0], "⦅_pf_tgt⦆")

Expand All @@ -135,10 +136,10 @@ def test_filter_too_long(self):
"src": ["Hello", "world", "."],
"tgt": ["Bonjour", "le", "monde", "."],
}
ex_out = filter_transform.apply(ex_in)
ex_out = filter_transform.apply(ex_in, is_train=True)
self.assertIs(ex_out, ex_in)
filter_transform.tgt_seq_length = 2
ex_out = filter_transform.apply(ex_in)
ex_out = filter_transform.apply(ex_in, is_train=True)
self.assertIsNone(ex_out)


Expand Down
2 changes: 1 addition & 1 deletion onmt/transforms/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
else:
corpus_prefix = self.prefix_dict.get('trans', None)
if corpus_prefix is None:
raise Exception('failed to set prefixes for translation')
raise ValueError('failed to set prefixes for translation')
return self._prepend(example, corpus_prefix)

def _repr_args(self):
Expand Down

0 comments on commit 9d6fb12

Please sign in to comment.