Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into iblsort
Browse files Browse the repository at this point in the history
  • Loading branch information
oliche committed Oct 25, 2024
2 parents bf8fea5 + cdf635d commit 4f652e7
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 18 deletions.
2 changes: 1 addition & 1 deletion ibllib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import warnings

__version__ = '2.39.1'
__version__ = '2.39.2'
warnings.filterwarnings('always', category=DeprecationWarning, module='ibllib')

# if this becomes a full-blown library we should let the logging configuration to the discretion of the dev
Expand Down
30 changes: 16 additions & 14 deletions ibllib/io/extractors/bpod_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
This module will extract the Bpod trials and wheel data based on the task protocol,
i.e. habituation, training or biased.
"""
import logging
import importlib

from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor
from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor, BaseExtractor
from ibllib.io.extractors.habituation_trials import HabituationTrials
from ibllib.io.extractors.training_trials import TrainingTrials
from ibllib.io.extractors.biased_trials import BiasedTrials, EphysTrials
from ibllib.io.extractors.base import BaseBpodTrialsExtractor

_logger = logging.getLogger(__name__)


def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavior_data') -> BaseBpodTrialsExtractor:
"""
Expand All @@ -39,20 +36,25 @@ def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavio
'BiasedTrials': BiasedTrials,
'EphysTrials': EphysTrials
}

if protocol:
class_name = protocol2extractor(protocol)
extractor_class_name = protocol2extractor(protocol)
else:
class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
if class_name in builtins:
return builtins[class_name](session_path)
extractor_class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
if extractor_class_name in builtins:
return builtins[extractor_class_name](session_path)

# look if there are custom extractor types in the personal projects repo
if not class_name.startswith('projects.'):
class_name = 'projects.' + class_name
module, class_name = class_name.rsplit('.', 1)
if not extractor_class_name.startswith('projects.'):
extractor_class_name = 'projects.' + extractor_class_name
module, extractor_class_name = extractor_class_name.rsplit('.', 1)
mdl = importlib.import_module(module)
extractor_class = getattr(mdl, class_name, None)
extractor_class = getattr(mdl, extractor_class_name, None)
if extractor_class:
return extractor_class(session_path)
my_extractor = extractor_class(session_path)
if not isinstance(my_extractor, BaseExtractor):
raise ValueError(
f"{my_extractor} should be an Extractor class inheriting from ibllib.io.extractors.base.BaseExtractor")
return my_extractor
else:
raise ValueError(f'extractor {class_name} not found')
raise ValueError(f'extractor {extractor_class_name} not found')
35 changes: 32 additions & 3 deletions ibllib/tests/extractors/test_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import shutil
import tempfile
import unittest
import unittest.mock
from unittest.mock import patch, Mock, MagicMock
from pathlib import Path

import numpy as np
import pandas as pd

import one.alf.io as alfio
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
from ibllib.io.extractors import training_trials, biased_trials, camera
from ibllib.io import raw_data_loaders as raw
from ibllib.io.extractors.base import BaseExtractor
Expand Down Expand Up @@ -530,13 +531,13 @@ def test_size_outputs(self):
'peakVelocity_times': np.array([1, 1])}
function_name = 'ibllib.io.extractors.training_wheel.extract_wheel_moves'
# Training
with unittest.mock.patch(function_name, return_value=mock_data):
with patch(function_name, return_value=mock_data):
task, = get_trials_tasks(self.training_lt5['path'])
trials, _ = task.extract_behaviour(save=True)
trials = alfio.load_object(self.training_lt5['path'] / 'alf', object='trials')
self.assertTrue(alfio.check_dimensions(trials) == 0)
# Biased
with unittest.mock.patch(function_name, return_value=mock_data):
with patch(function_name, return_value=mock_data):
task, = get_trials_tasks(self.biased_lt5['path'])
trials, _ = task.extract_behaviour(save=True)
trials = alfio.load_object(self.biased_lt5['path'] / 'alf', object='trials')
Expand Down Expand Up @@ -753,5 +754,33 @@ def test_attribute_times(self, display=False):
camera.attribute_times(tsa, tsb, injective=False, take='closest')


class TestGetBpodExtractor(unittest.TestCase):

def test_get_bpod_extractor(self):
# un-existing extractor should raise a value error
with self.assertRaises(ValueError):
get_bpod_extractor('', protocol='sdf', task_collection='raw_behavior_data')
# in this case this returns an ibllib.io.extractors.training_trials.TrainingTrials instance
extractor = get_bpod_extractor(
'', protocol='_trainingChoiceWorld',
task_collection='raw_behavior_data'
)
self.assertTrue(isinstance(extractor, BaseExtractor))

def test_get_bpod_custom_extractor(self):
# here we'll mock a custom module with a custom extractor
DummyModule = MagicMock()
DummyExtractor = Mock(spec_set=BaseExtractor)
DummyModule.toto.return_value = DummyExtractor
base_module = 'ibllib.io.extractors.bpod_trials'
with patch(f'{base_module}.get_bpod_extractor_class', return_value='toto'), \
patch(f'{base_module}.importlib.import_module', return_value=DummyModule) as import_mock:
self.assertIs(get_bpod_extractor(''), DummyExtractor)
import_mock.assert_called_with('projects')
# Check raises when imported class not an extractor
DummyModule.toto.return_value = MagicMock(spec=dict)
self.assertRaisesRegex(ValueError, 'should be an Extractor class', get_bpod_extractor, '')


if __name__ == '__main__':
unittest.main(exit=False, verbosity=2)
4 changes: 4 additions & 0 deletions release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

#### 2.39.1
- Bugfix: brainbox.metrics.single_unit.quick_unit_metrics fix for indexing of n_spike_below2
-
#### 2.39.2
- Bugfix: routing of protocol to extractor through the project repository checks that the
target is indeed an extractor class.

## Release Note 2.38.0

Expand Down

0 comments on commit 4f652e7

Please sign in to comment.