Skip to content

Commit

Permalink
Fixed a filter name resolution error
Browse files Browse the repository at this point in the history
  • Loading branch information
Habush committed Apr 10, 2019
1 parent 77ce337 commit 1c5c3f6
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
8 changes: 7 additions & 1 deletion crossval/filters/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@
import importlib
import inspect
from crossval.filters.base_filter import BaseFilter
import logging

logger = logging.getLogger("mozi_snet")


def get_score_filters(filter_name):
try:
filter_name = filter_name.lower().replace("-", "").replace("_", "")
module = importlib.import_module(".score_filters", "crossval.filters")
classes = inspect.getmembers(module, lambda k: inspect.isclass(k))

for name, _class in classes:
if issubclass(_class, BaseFilter) and name.lower().find(filter_name) != -1:
if issubclass(_class, BaseFilter) and name.lower().find(filter_name.lower()) != -1:
logger.info(f"Using {filter_name} score for filtering")
return _class()

except ImportError:
logger.error(f"Couldn't find a filter class for {filter_name}")
return None


Expand Down
2 changes: 1 addition & 1 deletion task/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def setup_periodic_task(sender, **kwargs):
sender.add_periodic_task(SCAN_INTERVAL, scan_expired_sessions.s(EXPIRY_SPAN), name="Scan for expired sessions")


@celery.task(name="task.task_runner.start_analysis")
@celery.task(name="task.task_runner.start_analysis", serializer="json")
def start_analysis(**kwargs):
"""
A celery task that runs the MOSES analysis
Expand Down
4 changes: 2 additions & 2 deletions tests/test_score_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
class TestScoreFilter(unittest.TestCase):

def test_loader(self):
filter_cls = loader.get_score_filters("precision")
filter_cls = loader.get_score_filters("P-value")

self.assertTrue(isinstance(filter_cls, score_filters.PrecisionFilter))
self.assertIsNotNone(filter_cls)

def test_filter_accuracy(self):
models = []
Expand Down

0 comments on commit 1c5c3f6

Please sign in to comment.