From 1c5c3f696b123c3a3d3afd1dcb4f570da70098d3 Mon Sep 17 00:00:00 2001 From: Abdulrahman Semrie Date: Wed, 10 Apr 2019 20:29:19 +0300 Subject: [PATCH] Fixed a filter name resolution error --- crossval/filters/loader.py | 8 +++++++- task/task_runner.py | 2 +- tests/test_score_filter.py | 4 ++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/crossval/filters/loader.py b/crossval/filters/loader.py index b8922ed..99997e1 100644 --- a/crossval/filters/loader.py +++ b/crossval/filters/loader.py @@ -3,18 +3,24 @@ import importlib import inspect from crossval.filters.base_filter import BaseFilter +import logging + +logger = logging.getLogger("mozi_snet") def get_score_filters(filter_name): try: + filter_name = filter_name.lower().replace("-", "").replace("_", "") module = importlib.import_module(".score_filters", "crossval.filters") classes = inspect.getmembers(module, lambda k: inspect.isclass(k)) for name, _class in classes: - if issubclass(_class, BaseFilter) and name.lower().find(filter_name) != -1: + if issubclass(_class, BaseFilter) and name.lower().find(filter_name.lower()) != -1: + logger.info(f"Using {filter_name} score for filtering") return _class() except ImportError: + logger.error(f"Couldn't find a filter class for {filter_name}") return None diff --git a/task/task_runner.py b/task/task_runner.py index 64f03cc..5a1cf42 100644 --- a/task/task_runner.py +++ b/task/task_runner.py @@ -67,7 +67,7 @@ def setup_periodic_task(sender, **kwargs): sender.add_periodic_task(SCAN_INTERVAL, scan_expired_sessions.s(EXPIRY_SPAN), name="Scan for expired sessions") -@celery.task(name="task.task_runner.start_analysis") +@celery.task(name="task.task_runner.start_analysis", serializer="json") def start_analysis(**kwargs): """ A celery task that runs the MOSES analysis diff --git a/tests/test_score_filter.py b/tests/test_score_filter.py index ce5047a..8a9850f 100644 --- a/tests/test_score_filter.py +++ b/tests/test_score_filter.py @@ -8,9 +8,9 @@ class TestScoreFilter(unittest.TestCase): def test_loader(self): - filter_cls = loader.get_score_filters("precision") + filter_cls = loader.get_score_filters("P-value") - self.assertTrue(isinstance(filter_cls, score_filters.PrecisionFilter)) + self.assertIsNotNone(filter_cls) def test_filter_accuracy(self): models = []