From 063c7133b999d2f046df3edd3770c7efc6bd9495 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Thu, 17 Oct 2024 15:38:45 +0300 Subject: [PATCH 1/2] Fix for https://github.com/int-brain-lab/iblrig/issues/728 --- ibllib/io/session_params.py | 19 +++++++++++++++++-- ibllib/tests/test_io.py | 20 +++++++++++++++++++- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index c344795c9..6e4848a40 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -27,6 +27,7 @@ import logging import socket from pathlib import Path +from itertools import chain from copy import deepcopy from one.converters import ConversionMixin @@ -77,6 +78,9 @@ def _patch_file(data: dict) -> dict: if 'tasks' in data and isinstance(data['tasks'], dict): data['tasks'] = [{k: v} for k, v in data['tasks'].copy().items()] data['version'] = SPEC_VERSION + # Ensure all items in tasks list are single value dicts + if 'tasks' in data: + data['tasks'] = [{k: v} for k, v in chain.from_iterable(map(dict.items, data['tasks']))] return data @@ -168,8 +172,19 @@ def merge_params(a, b, copy=False): assert k not in a or a[k] == b[k], 'multiple sync fields defined' if isinstance(b[k], list): prev = list(a.get(k, [])) - # For procedures and projects, remove duplicates - to_add = b[k] if k == 'tasks' else set(b[k]) - set(prev) + if k == 'tasks': + # For tasks, keep order and skip duplicates + # Assert tasks is a list of single value dicts + assert set(map(len, prev)) == {1} and set(map(len, b[k])) == {1} + # Convert protocol -> dict map to hashable tuple of protocol + sorted key value pairs + to_hashable = lambda itm: (itm[0], *chain.from_iterable(sorted(itm[1].items()))) # noqa + # Get the set of previous tasks + prev_tasks = set(map(to_hashable, chain.from_iterable(map(dict.items, prev)))) + tasks = chain.from_iterable(map(dict.items, b[k])) + to_add = [dict([itm]) for itm in tasks if to_hashable(itm) not in prev_tasks] + else: + # For procedures and projects, remove duplicates + to_add = set(b[k]) - set(prev) a[k] = prev + list(to_add) elif isinstance(b[k], dict): a[k] = {**a.get(k, {}), **b[k]} diff --git a/ibllib/tests/test_io.py b/ibllib/tests/test_io.py index 3ef9574e4..661c4d2a4 100644 --- a/ibllib/tests/test_io.py +++ b/ibllib/tests/test_io.py @@ -527,10 +527,22 @@ def test_read_yaml(self): self.assertCountEqual(self.fixture.keys(), data_keys) def test_patch_data(self): + """Test for session_params._patch_file function.""" with patch(session_params.__name__ + '.SPEC_VERSION', '1.0.0'), \ self.assertLogs(session_params.__name__, logging.WARNING): data = session_params._patch_file({'version': '1.1.0'}) self.assertEqual(data, {'version': '1.0.0'}) + # Check tasks dicts separated into lists + unpatched = {'version': '0.0.1', 'tasks': { + 'fooChoiceWorld': {1: '1'}, 'barChoiceWorld': {2: '2'}}} + data = session_params._patch_file(unpatched) + self.assertIsInstance(data['tasks'], list) + self.assertEqual([['fooChoiceWorld'], ['barChoiceWorld']], list(map(list, data['tasks']))) + # Check patching list of dicts with some containing more than 1 key + unpatched = {'tasks': [{'foo': {1: '1'}}, {'bar': {2: '2'}, 'baz': {3: '3'}}]} + data = session_params._patch_file(unpatched) + self.assertEqual(3, len(data['tasks'])) + self.assertEqual([['foo'], ['bar'], ['baz']], list(map(list, data['tasks']))) def test_get_collections(self): collections = session_params.get_collections(self.fixture) @@ -561,10 +573,16 @@ def test_merge_params(self): b = {'procedures': ['Imaging', 'Injection'], 'tasks': [{'fooChoiceWorld': {'collection': 'bar'}}]} c = session_params.merge_params(a, b, copy=True) self.assertCountEqual(['Imaging', 'Behavior training/tasks', 'Injection'], c['procedures']) - self.assertCountEqual(['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld'], (list(x)[0] for x in c['tasks'])) + self.assertEqual(['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld'], [list(x)[0] for x in c['tasks']]) # Ensure a and b not modified self.assertNotEqual(set(c['procedures']), set(a['procedures'])) self.assertNotEqual(set(a['procedures']), set(b['procedures'])) + # Test duplicate tasks skipped while order kept constant + d = {'tasks': [a['tasks'][1], {'ephysChoiceWorld': {'collection': 'raw_task_data_02', 'sync_label': 'nidq'}}]} + e = session_params.merge_params(c, d, copy=True) + expected = ['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld', 'ephysChoiceWorld'] + self.assertEqual(expected, [list(x)[0] for x in e['tasks']]) + self.assertDictEqual({'collection': 'raw_task_data_02', 'sync_label': 'nidq'}, e['tasks'][-1]['ephysChoiceWorld']) # Test without copy session_params.merge_params(a, b, copy=False) self.assertCountEqual(['Imaging', 'Behavior training/tasks', 'Injection'], a['procedures']) From 4308f6d79debd7cdd8483ef89556ef1c38c998a8 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Thu, 17 Oct 2024 16:14:14 +0300 Subject: [PATCH 2/2] Handle empty tasks key --- ibllib/io/session_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index 6e4848a40..e9127e9ae 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -175,7 +175,7 @@ def merge_params(a, b, copy=False): if k == 'tasks': # For tasks, keep order and skip duplicates # Assert tasks is a list of single value dicts - assert set(map(len, prev)) == {1} and set(map(len, b[k])) == {1} + assert (not prev or set(map(len, prev)) == {1}) and set(map(len, b[k])) == {1} # Convert protocol -> dict map to hashable tuple of protocol + sorted key value pairs to_hashable = lambda itm: (itm[0], *chain.from_iterable(sorted(itm[1].items()))) # noqa # Get the set of previous tasks