Skip to content

Commit

Permalink
Merge branch 'develop' into iblsort
Browse files Browse the repository at this point in the history
  • Loading branch information
oliche committed Oct 24, 2024
2 parents 50ba7fb + 09905da commit 1d11edc
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 56 deletions.
2 changes: 1 addition & 1 deletion ibllib/io/extractors/fibrephotometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _extract(self, light_source_map=None, collection=None, regions=None, **kwarg
regions = regions or [k for k in fp_data['raw'].keys() if 'Region' in k]
out_df = fp_data['raw'].filter(items=regions, axis=1).sort_index(axis=1)
out_df['times'] = ts
out_df['wavelength'] = np.NaN
out_df['wavelength'] = np.nan
out_df['name'] = ''
out_df['color'] = ''
# Extract channel index
Expand Down
8 changes: 4 additions & 4 deletions ibllib/io/extractors/opto_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class LaserBool(BaseBpodTrialsExtractor):
def _extract(self, **kwargs):
_logger.info('Extracting laser datasets')
# reference pybpod implementation
lstim = np.array([float(t.get('laser_stimulation', np.NaN)) for t in self.bpod_trials])
lprob = np.array([float(t.get('laser_probability', np.NaN)) for t in self.bpod_trials])
lstim = np.array([float(t.get('laser_stimulation', np.nan)) for t in self.bpod_trials])
lprob = np.array([float(t.get('laser_probability', np.nan)) for t in self.bpod_trials])

# Karolina's choice world legacy implementation - from Slack message:
# it is possible that some versions I have used:
Expand All @@ -30,9 +30,9 @@ def _extract(self, **kwargs):
# laserOFF_trials=(optoOUT ==0);
if 'PROBABILITY_OPTO' in self.settings.keys() and np.all(np.isnan(lstim)):
lprob = np.zeros_like(lprob) + self.settings['PROBABILITY_OPTO']
lstim = np.array([float(t.get('opto_ON_time', np.NaN)) for t in self.bpod_trials])
lstim = np.array([float(t.get('opto_ON_time', np.nan)) for t in self.bpod_trials])
if np.all(np.isnan(lstim)):
lstim = np.array([float(t.get('optoOUT', np.NaN)) for t in self.bpod_trials])
lstim = np.array([float(t.get('optoOUT', np.nan)) for t in self.bpod_trials])
lstim[lstim == 255] = 1
else:
lstim[~np.isnan(lstim)] = 1
Expand Down
2 changes: 1 addition & 1 deletion ibllib/io/extractors/training_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _extract(self):
feedbackType = np.zeros(len(self.bpod_trials), np.int64)
for i, t in enumerate(self.bpod_trials):
state_names = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
outcome = {sn: ~np.isnan(t['behavior_data']['States timestamps'].get(sn, [[np.NaN]])[0][0]) for sn in state_names}
outcome = {sn: ~np.isnan(t['behavior_data']['States timestamps'].get(sn, [[np.nan]])[0][0]) for sn in state_names}
assert np.sum(list(outcome.values())) == 1
outcome = next(k for k in outcome if outcome[k])
if outcome == 'correct':
Expand Down
19 changes: 17 additions & 2 deletions ibllib/io/session_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import logging
import socket
from pathlib import Path
from itertools import chain
from copy import deepcopy

from one.converters import ConversionMixin
Expand Down Expand Up @@ -77,6 +78,9 @@ def _patch_file(data: dict) -> dict:
if 'tasks' in data and isinstance(data['tasks'], dict):
data['tasks'] = [{k: v} for k, v in data['tasks'].copy().items()]
data['version'] = SPEC_VERSION
# Ensure all items in tasks list are single value dicts
if 'tasks' in data:
data['tasks'] = [{k: v} for k, v in chain.from_iterable(map(dict.items, data['tasks']))]
return data


Expand Down Expand Up @@ -168,8 +172,19 @@ def merge_params(a, b, copy=False):
assert k not in a or a[k] == b[k], 'multiple sync fields defined'
if isinstance(b[k], list):
prev = list(a.get(k, []))
# For procedures and projects, remove duplicates
to_add = b[k] if k == 'tasks' else set(b[k]) - set(prev)
if k == 'tasks':
# For tasks, keep order and skip duplicates
# Assert tasks is a list of single value dicts
assert (not prev or set(map(len, prev)) == {1}) and set(map(len, b[k])) == {1}
# Convert protocol -> dict map to hashable tuple of protocol + sorted key value pairs
to_hashable = lambda itm: (itm[0], *chain.from_iterable(sorted(itm[1].items()))) # noqa
# Get the set of previous tasks
prev_tasks = set(map(to_hashable, chain.from_iterable(map(dict.items, prev))))
tasks = chain.from_iterable(map(dict.items, b[k]))
to_add = [dict([itm]) for itm in tasks if to_hashable(itm) not in prev_tasks]
else:
# For procedures and projects, remove duplicates
to_add = set(b[k]) - set(prev)
a[k] = prev + list(to_add)
elif isinstance(b[k], dict):
a[k] = {**a.get(k, {}), **b[k]}
Expand Down
51 changes: 35 additions & 16 deletions ibllib/pipes/mesoscope_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,23 +486,42 @@ def _consolidate_exptQC(exptQC):
numpy.array
An array of frame indices where QC code != 0.
"""

# Merge and make sure same indexes have same names across all files
frameQC_names_list = [e['frameQC_names'] for e in exptQC]
frameQC_names_list = [{k: i for i, k in enumerate(ensure_list(f))} for f in frameQC_names_list]
frameQC_names = {k: v for d in frameQC_names_list for k, v in d.items()}
for d in frameQC_names_list:
for k, v in d.items():
if frameQC_names[k] != v:
raise IOError(f'exptQC.mat files have different values for name "{k}"')

frameQC_names = pd.DataFrame(sorted([(v, k) for k, v in frameQC_names.items()]),
columns=['qc_values', 'qc_labels'])

# Create a new enumeration combining all unique QC labels.
# 'ok' will always have an enum of 0, the rest are determined by order alone
qc_labels = ['ok']
frame_qc = []
for e in exptQC:
assert e.keys() >= set(['frameQC_names', 'frameQC_frames'])
# Initialize a NaN array the same size as frameQC_frames to fill with new enum values
frames = np.full(e['frameQC_frames'].shape, fill_value=np.nan)
# May be numpy array of str or a single str, in both cases we cast to list of str
names = list(ensure_list(e['frameQC_names']))
# For each label for the old enum, populate initialized array with the new one
for name in names:
i_old = names.index(name) # old enumeration
name = name if len(name) else 'unknown' # handle empty array and empty str
try:
i_new = qc_labels.index(name)
except ValueError:
i_new = len(qc_labels)
qc_labels.append(name)
frames[e['frameQC_frames'] == i_old] = i_new
frame_qc.append(frames)
# Concatenate frames
frameQC = np.concatenate([e['frameQC_frames'] for e in exptQC], axis=0)
bad_frames = np.where(frameQC != 0)[0]
return frameQC, frameQC_names, bad_frames
frame_qc = np.concatenate(frame_qc)
# If any NaNs left over, assign 'unknown' label
if (missing_name := np.isnan(frame_qc)).any():
try:
i = qc_labels.index('unknown')
except ValueError:
i = len(qc_labels)
qc_labels.append('unknown')
frame_qc[missing_name] = i
frame_qc = frame_qc.astype(np.uint32) # cast to uint32
bad_frames, = np.where(frame_qc != 0)
# Convert labels to value -> label data frame
frame_qc_names = pd.DataFrame(list(enumerate(qc_labels)), columns=['qc_values', 'qc_labels'])
return frame_qc, frame_qc_names, bad_frames

def get_default_tau(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion ibllib/qc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def overall_outcome(outcomes: iter, agg=max) -> spec.QC:
one.alf.spec.QC
The overall outcome.
"""
outcomes = filter(lambda x: x not in (None, np.NaN), outcomes)
outcomes = filter(lambda x: x not in (None, np.nan), outcomes)
return agg(map(spec.QC.validate, outcomes))

def _set_eid_or_path(self, session_path_or_eid):
Expand Down
2 changes: 1 addition & 1 deletion ibllib/qc/task_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def compute(self, **kwargs):
iti = (np.roll(data['stimOn_times'], -1) - data['stimOff_times'])[:-1]
metric = np.r_[np.nan_to_num(iti, nan=np.inf), np.nan] - 1.
passed[check] = np.abs(metric) <= 0.1
passed[check][-1] = np.NaN
passed[check][-1] = np.nan
metrics[check] = metric

# Checks common to training QC
Expand Down
20 changes: 19 additions & 1 deletion ibllib/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,22 @@ def test_read_yaml(self):
self.assertCountEqual(self.fixture.keys(), data_keys)

def test_patch_data(self):
    """Check that session_params._patch_file normalizes the version and tasks fields."""
    # With the spec version patched to 1.0.0, a file with version 1.1.0 should
    # log a warning and come back with the patched spec version
    spec_version_attr = session_params.__name__ + '.SPEC_VERSION'
    with patch(spec_version_attr, '1.0.0'), \
            self.assertLogs(session_params.__name__, logging.WARNING):
        data = session_params._patch_file({'version': '1.1.0'})
    self.assertEqual(data, {'version': '1.0.0'})

    # A tasks dict should be converted to a list with one dict per protocol
    legacy = {
        'version': '0.0.1',
        'tasks': {'fooChoiceWorld': {1: '1'}, 'barChoiceWorld': {2: '2'}},
    }
    patched = session_params._patch_file(legacy)
    self.assertIsInstance(patched['tasks'], list)
    protocols = [list(task) for task in patched['tasks']]
    self.assertEqual([['fooChoiceWorld'], ['barChoiceWorld']], protocols)

    # A list item holding more than one protocol should be split into single key dicts
    multi_key = {'tasks': [{'foo': {1: '1'}}, {'bar': {2: '2'}, 'baz': {3: '3'}}]}
    patched = session_params._patch_file(multi_key)
    self.assertEqual(3, len(patched['tasks']))
    protocols = [list(task) for task in patched['tasks']]
    self.assertEqual([['foo'], ['bar'], ['baz']], protocols)

def test_get_collections(self):
collections = session_params.get_collections(self.fixture)
Expand Down Expand Up @@ -561,10 +573,16 @@ def test_merge_params(self):
b = {'procedures': ['Imaging', 'Injection'], 'tasks': [{'fooChoiceWorld': {'collection': 'bar'}}]}
c = session_params.merge_params(a, b, copy=True)
self.assertCountEqual(['Imaging', 'Behavior training/tasks', 'Injection'], c['procedures'])
self.assertCountEqual(['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld'], (list(x)[0] for x in c['tasks']))
self.assertEqual(['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld'], [list(x)[0] for x in c['tasks']])
# Ensure a and b not modified
self.assertNotEqual(set(c['procedures']), set(a['procedures']))
self.assertNotEqual(set(a['procedures']), set(b['procedures']))
# Test duplicate tasks skipped while order kept constant
d = {'tasks': [a['tasks'][1], {'ephysChoiceWorld': {'collection': 'raw_task_data_02', 'sync_label': 'nidq'}}]}
e = session_params.merge_params(c, d, copy=True)
expected = ['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld', 'ephysChoiceWorld']
self.assertEqual(expected, [list(x)[0] for x in e['tasks']])
self.assertDictEqual({'collection': 'raw_task_data_02', 'sync_label': 'nidq'}, e['tasks'][-1]['ephysChoiceWorld'])
# Test without copy
session_params.merge_params(a, b, copy=False)
self.assertCountEqual(['Imaging', 'Behavior training/tasks', 'Injection'], a['procedures'])
Expand Down
36 changes: 16 additions & 20 deletions ibllib/tests/test_mesoscope.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,30 +102,26 @@ def test_consolidate_exptQC(self):
exptQC = [
{'frameQC_names': np.array(['ok', 'PMT off', 'galvos fault', 'high signal'], dtype=object),
'frameQC_frames': np.array([0, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4])},
{'frameQC_names': np.array(['ok', 'PMT off', 'galvos fault', 'high signal'], dtype=object),
'frameQC_frames': np.zeros(50, dtype=int)}
{'frameQC_names': np.array(['ok', 'PMT off', 'foo', 'galvos fault', np.array([])], dtype=object),
'frameQC_frames': np.array([0, 0, 1, 1, 2, 2, 2, 2, 3, 4])},
{'frameQC_names': 'ok', # check with single str instead of array
'frameQC_frames': np.array([0, 0])}
]

# Check that frame QC arrays are concatenated
frameQC, frameQC_names, bad_frames = self.task._consolidate_exptQC(exptQC)
expected_frames = np.r_[exptQC[0]['frameQC_frames'], exptQC[1]['frameQC_frames']]
np.testing.assert_array_equal(expected_frames, frameQC)
expected = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
frame_qc, frame_qc_names, bad_frames = self.task._consolidate_exptQC(exptQC)
# Check frame_qc array
expected_frames = [
0, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 5, 0, 0, 1, 1, 4, 4, 4, 4, 2, 5, 0, 0]
np.testing.assert_array_equal(expected_frames, frame_qc)
# Check bad_frames array
expected = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25]
np.testing.assert_array_equal(expected, bad_frames)
self.assertCountEqual(['qc_values', 'qc_labels'], frameQC_names.columns)
self.assertCountEqual(range(4), frameQC_names['qc_values'])
expected = ['ok', 'PMT off', 'galvos fault', 'high signal']
self.assertCountEqual(expected, frameQC_names['qc_labels'])

# Check with single str instead of array
exptQC[1]['frameQC_names'] = 'ok'
frameQC, frameQC_names, bad_frames = self.task._consolidate_exptQC(exptQC)
self.assertCountEqual(expected, frameQC_names['qc_labels'])
np.testing.assert_array_equal(expected_frames, frameQC)
# Check with inconsistent enumerations
exptQC[0]['frameQC_names'] = expected
exptQC[1]['frameQC_names'] = [*expected[-2:], *expected[:-2]]
self.assertRaises(IOError, self.task._consolidate_exptQC, exptQC)
# Check frame_qc_names data frame
self.assertCountEqual(['qc_values', 'qc_labels'], frame_qc_names.columns)
self.assertEqual(list(range(6)), frame_qc_names['qc_values'].tolist())
expected = ['ok', 'PMT off', 'galvos fault', 'high signal', 'foo', 'unknown']
self.assertCountEqual(expected, frame_qc_names['qc_labels'].tolist())

def test_setup_uncompressed(self):
"""Test set up behaviour when raw tifs present."""
Expand Down
12 changes: 3 additions & 9 deletions ibllib/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,9 @@ class TestPipelineAlyx(unittest.TestCase):

def setUp(self) -> None:
self.td = tempfile.TemporaryDirectory()
# ses = one.alyx.rest('sessions', 'list', subject=ses_dict['subject'],
# date_range=[ses_dict['start_time'][:10]] * 2,
# number=ses_dict['number'],
# no_cache=True)
# if len(ses):
# one.alyx.rest('sessions', 'delete', ses[0]['url'][-36:])
# randomise number
ses_dict['number'] = np.random.randint(1, 30)
ses = one.alyx.rest('sessions', 'create', data=ses_dict)
self.ses_dict = ses_dict.copy()
self.ses_dict['number'] = np.random.randint(1, 999)
ses = one.alyx.rest('sessions', 'create', data=self.ses_dict)
session_path = Path(self.td.name).joinpath(
ses['subject'], ses['start_time'][:10], str(ses['number']).zfill(3))
session_path.joinpath('alf').mkdir(exist_ok=True, parents=True)
Expand Down

0 comments on commit 1d11edc

Please sign in to comment.