From 9d6fb12481a8da1159f7a42f56bfdd4058bc576b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 2 Oct 2023 10:50:17 +0300 Subject: [PATCH] Fix failing unit tests - Prefix transform requires is_train to be correctly set. - Always raising ValueError on errors, although it is not completely appropriate. --- onmt/tests/test_transform.py | 11 ++++++----- onmt/transforms/misc.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index 234e6ebb..b0543dba 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -85,7 +85,7 @@ def test_transform_pipe(self): "tgt": ["Bonjour", "le", "monde", "."], } # 4. apply transform pipe for example - ex_after = transform_pipe.apply(copy.deepcopy(ex), corpus_name="trainset") + ex_after = transform_pipe.apply(copy.deepcopy(ex), is_train=True, corpus_name="trainset") # 5. example after the pipe exceed the length limit, thus filtered self.assertIsNone(ex_after) # 6. Transform statistics registed (here for filtertoolong) @@ -121,8 +121,9 @@ def test_prefix(self): } with self.assertRaises(ValueError): prefix_transform.apply(ex_in) - prefix_transform.apply(ex_in, corpus_name="validset") - ex_out = prefix_transform.apply(ex_in, corpus_name="trainset") + with self.assertRaises(ValueError): + prefix_transform.apply(ex_in, is_train=False, corpus_name="validset") + ex_out = prefix_transform.apply(ex_in, is_train=True, corpus_name="trainset") self.assertEqual(ex_out["src"][0], "⦅_pf_src⦆") self.assertEqual(ex_out["tgt"][0], "⦅_pf_tgt⦆") @@ -135,10 +136,10 @@ def test_filter_too_long(self): "src": ["Hello", "world", "."], "tgt": ["Bonjour", "le", "monde", "."], } - ex_out = filter_transform.apply(ex_in) + ex_out = filter_transform.apply(ex_in, is_train=True) self.assertIs(ex_out, ex_in) filter_transform.tgt_seq_length = 2 - ex_out = filter_transform.apply(ex_in) + ex_out = filter_transform.apply(ex_in, is_train=True) self.assertIsNone(ex_out) diff --git a/onmt/transforms/misc.py b/onmt/transforms/misc.py index 0c50ef59..58c501b4 100644 --- a/onmt/transforms/misc.py +++ b/onmt/transforms/misc.py @@ -129,7 +129,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs): else: corpus_prefix = self.prefix_dict.get('trans', None) if corpus_prefix is None: - raise Exception('failed to set prefixes for translation') + raise ValueError('failed to set prefixes for translation') return self._prepend(example, corpus_prefix) def _repr_args(self):