From 49efc92d1f689b3ddf8a80e66dbe94eb9dfe9146 Mon Sep 17 00:00:00 2001
From: Sami Virpioja
Date: Wed, 3 Apr 2024 09:57:32 +0300
Subject: [PATCH] fix score method in SentenceEmbeddingFilter

---
 opusfilter/embeddings.py |  3 ++-
 tests/test_embeddings.py | 15 +++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/opusfilter/embeddings.py b/opusfilter/embeddings.py
index 2f993c8..2ad5000 100644
--- a/opusfilter/embeddings.py
+++ b/opusfilter/embeddings.py
@@ -133,7 +133,8 @@ def _score_chunk(self, chunk):
 
     def score(self, pairs):
         for chunk in grouper(pairs, self.chunksize):
-            return self._score_chunk(chunk)
+            for score in self._score_chunk(chunk):
+                yield score
 
     def accept(self, score):
         return all(similarity >= self.threshold for similarity in score)
diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py
index bd3a973..4e22725 100644
--- a/tests/test_embeddings.py
+++ b/tests/test_embeddings.py
@@ -8,6 +8,7 @@
 
 from opusfilter import ConfigurationError
 from opusfilter.embeddings import *
+from opusfilter.pipeline import FilterPipeline
 
 
 try:
@@ -80,3 +81,17 @@ def test_bilingual_margin_ratios(self):
         results = [testfilter.accept(x) for x in testfilter.score(self.bi_inputs)]
         for result, correct in zip(results, expected):
             self.assertEqual(result, correct)
+
+    def test_chunking(self):
+        testfilter = SentenceEmbeddingFilter(languages=self.bi_langs, threshold=0.4, chunksize=19)
+        inputs = 50 * self.bi_inputs
+        expected = 50 * [True, True, False, False]
+        results = [testfilter.accept(x) for x in testfilter.score(inputs)]
+        for result, correct in zip(results, expected):
+            self.assertEqual(result, correct)
+        pipeline = FilterPipeline(filters=[testfilter])
+        pipeline.chunksize = 30
+        filtered = list(pipeline.filter(inputs))
+        self.assertEqual(len(filtered), len([x for x in expected if x]))
+        scores = list(pipeline.score(inputs))
+        self.assertEqual(len(scores), len(expected))