Skip to content

Commit

Permalink
Fix patch_settings for iblrigv8 (#713)
Browse files Browse the repository at this point in the history
* Fix patch_settings for iblrigv8
* Increase coverage
  • Loading branch information
k1o0 authored Jan 18, 2024
1 parent 550f783 commit 9f43340
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 26 deletions.
19 changes: 7 additions & 12 deletions ibllib/io/raw_data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,16 +736,6 @@ def _groom_wheel_data_ge5(data, label='file ', path=''):
return data


def save_bool(save, dataset_type):
if isinstance(save, bool):
out = save
elif isinstance(save, list):
out = (dataset_type in save) or (Path(dataset_type).stem in save)
if out:
_logger.debug('extracting' + dataset_type)
return out


def sync_trials_robust(t0, t1, diff_threshold=0.001, drift_threshold_ppm=200, max_shift=5,
return_index=False):
"""
Expand Down Expand Up @@ -945,7 +935,7 @@ def patch_settings(session_path, collection='raw_behavior_data',
if not settings:
raise IOError('Settings file not found')

filename = PureWindowsPath(settings['SETTINGS_FILE_PATH']).name
filename = PureWindowsPath(settings.get('SETTINGS_FILE_PATH', '_iblrig_taskSettings.raw.json')).name
file_path = Path(session_path).joinpath(collection, filename)

if subject:
Expand All @@ -955,7 +945,8 @@ def patch_settings(session_path, collection='raw_behavior_data',
for k in settings.keys():
if isinstance(settings[k], str):
settings[k] = settings[k].replace(f'\\Subjects\\{old_subject}', f'\\Subjects\\{subject}')
settings['SESSION_NAME'] = '\\'.join([subject, *settings['SESSION_NAME'].split('\\')[1:]])
if 'SESSION_NAME' in settings:
settings['SESSION_NAME'] = '\\'.join([subject, *settings['SESSION_NAME'].split('\\')[1:]])
settings.pop('PYBPOD_SUBJECT_EXTRA', None) # Get rid of Alyx subject info

if date:
Expand All @@ -970,6 +961,10 @@ def patch_settings(session_path, collection='raw_behavior_data',
f'\\{settings["SUBJECT_NAME"]}\\{date}'
)
settings['SESSION_DATETIME'] = date + settings['SESSION_DATETIME'][10:]
if 'SESSION_END_TIME' in settings:
settings['SESSION_END_TIME'] = date + settings['SESSION_END_TIME'][10:]
if 'SESSION_START_TIME' in settings:
settings['SESSION_START_TIME'] = date + settings['SESSION_START_TIME'][10:]

if number:
# Patch session number
Expand Down
18 changes: 5 additions & 13 deletions ibllib/io/session_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from copy import deepcopy

from one.converters import ConversionMixin
from iblutil.util import flatten
from packaging import version

import ibllib.pipes.misc as misc
Expand Down Expand Up @@ -391,15 +392,15 @@ def get_collections(sess_params, flat=False):
sess_params : dict
The loaded experiment description map.
flat : bool (False)
If True, return a flat list of unique collections, otherwise return a map of device/sync/task
If True, return a flat set of collections, otherwise return a map of device/sync/task
Returns
-------
dict[str, str]
A map of device/sync/task and the corresponding collection name.
list[str]
A flat list of unique collection names.
set[str]
A set of unique collection names.
Notes
-----
Expand All @@ -423,16 +424,7 @@ def iter_dict(d):
iter_dict(v)

iter_dict(sess_params)
if flat:
cflat = []
for k, v in collection_map.items():
if isinstance(v, list):
cflat.extend(v)
else:
cflat.append(v)
return list(set(cflat))
else:
return collection_map
return set(flatten(collection_map.values())) if flat else collection_map


def get_video_compressed(sess_params):
Expand Down
18 changes: 17 additions & 1 deletion ibllib/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,15 @@ def test_load_encoder_trial_info(self):
self.session = Path(__file__).parent.joinpath('extractors', 'data', 'session_biased_ge5')
data = raw.load_encoder_trial_info(self.session)
self.assertTrue(data is not None)
self.assertIsNone(raw.load_encoder_trial_info(self.session.with_name('empty')))
self.assertIsNone(raw.load_encoder_trial_info(None))

def test_load_camera_ssv_times(self):
session = Path(__file__).parent.joinpath('extractors', 'data', 'session_ephys')
with self.assertRaises(ValueError):
raw.load_camera_ssv_times(session, 'tail')
with self.assertRaises(FileNotFoundError):
raw.load_camera_ssv_times(session.with_name('foobar'), 'body')
bonsai, camera = raw.load_camera_ssv_times(session, 'body')
self.assertTrue(bonsai.size == camera.size == 6001)
self.assertEqual(bonsai.dtype.str, '<M8[ns]')
Expand Down Expand Up @@ -297,6 +301,18 @@ def test_load_camera_frameData(self):
self.assertTrue(fd.dtypes.to_dict() == parsed_dtypes)
self.assertTrue(all([x == np.int64 for x in fd_raw.dtypes]))

def test_load_settings(self):
main_path = Path(__file__).parent.joinpath('extractors', 'data')
self.training_ge5 = main_path / 'session_training_ge5'
settings = raw.load_settings(self.training_ge5)
self.assertIsInstance(settings, dict)
self.assertEqual(144, len(settings))
with self.assertLogs(raw._logger, level=20):
self.assertIsNone(raw.load_settings(None))
# Should return None when path empty
with self.assertLogs(raw._logger, level=20):
self.assertIsNone(raw.load_settings(self.training_ge5, 'raw_task_data_00'))

def tearDown(self):
self.tempfile.close()
os.unlink(self.tempfile.name)
Expand Down Expand Up @@ -588,7 +604,7 @@ def test_get_collections_repeat_protocols(self):
collections = session_params.get_collections(tasks)
self.assertEqual(set(collections['passiveChoiceWorld']), set(['raw_passive_data_bis', 'raw_passive_data']))
collections = session_params.get_collections(tasks, flat=True)
self.assertEqual(set(collections), set(['raw_passive_data_bis', 'raw_passive_data', 'raw_behavior_data']))
self.assertEqual(collections, {'raw_passive_data_bis', 'raw_passive_data', 'raw_behavior_data'})


class TestRawDaqLoaders(unittest.TestCase):
Expand Down

0 comments on commit 9f43340

Please sign in to comment.