Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Duplicate experiment description tasks #867

Merged
merged 2 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions ibllib/io/session_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import logging
import socket
from pathlib import Path
from itertools import chain
from copy import deepcopy

from one.converters import ConversionMixin
Expand Down Expand Up @@ -77,6 +78,9 @@ def _patch_file(data: dict) -> dict:
if 'tasks' in data and isinstance(data['tasks'], dict):
data['tasks'] = [{k: v} for k, v in data['tasks'].copy().items()]
data['version'] = SPEC_VERSION
# Ensure all items in tasks list are single value dicts
if 'tasks' in data:
data['tasks'] = [{k: v} for k, v in chain.from_iterable(map(dict.items, data['tasks']))]
return data


Expand Down Expand Up @@ -168,8 +172,19 @@ def merge_params(a, b, copy=False):
assert k not in a or a[k] == b[k], 'multiple sync fields defined'
if isinstance(b[k], list):
prev = list(a.get(k, []))
# For procedures and projects, remove duplicates
to_add = b[k] if k == 'tasks' else set(b[k]) - set(prev)
if k == 'tasks':
# For tasks, keep order and skip duplicates
# Assert tasks is a list of single value dicts
assert (not prev or set(map(len, prev)) == {1}) and set(map(len, b[k])) == {1}
# Convert protocol -> dict map to hashable tuple of protocol + sorted key value pairs
to_hashable = lambda itm: (itm[0], *chain.from_iterable(sorted(itm[1].items()))) # noqa
# Get the set of previous tasks
prev_tasks = set(map(to_hashable, chain.from_iterable(map(dict.items, prev))))
tasks = chain.from_iterable(map(dict.items, b[k]))
to_add = [dict([itm]) for itm in tasks if to_hashable(itm) not in prev_tasks]
else:
# For procedures and projects, remove duplicates
to_add = set(b[k]) - set(prev)
a[k] = prev + list(to_add)
elif isinstance(b[k], dict):
a[k] = {**a.get(k, {}), **b[k]}
Expand Down
20 changes: 19 additions & 1 deletion ibllib/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,22 @@ def test_read_yaml(self):
self.assertCountEqual(self.fixture.keys(), data_keys)

def test_patch_data(self):
"""Test for session_params._patch_file function.

Three behaviours are exercised:

1. With SPEC_VERSION patched to '1.0.0', patching data whose version is '1.1.0'
   logs at WARNING level (or above) on the session_params logger and the
   returned version is '1.0.0'.
2. A 'tasks' field given as a single dict of protocol -> parameters is returned
   as a list with one single-key dict per protocol, in the original order.
3. A 'tasks' list containing a dict with more than one key is flattened so that
   every list element holds exactly one key.
"""
# NOTE(review): presumably the warning is raised because the file version is
# newer than the (patched) spec version - confirm against _patch_file.
with patch(session_params.__name__ + '.SPEC_VERSION', '1.0.0'), \
self.assertLogs(session_params.__name__, logging.WARNING):
data = session_params._patch_file({'version': '1.1.0'})
self.assertEqual(data, {'version': '1.0.0'})
# Check tasks dicts separated into lists
unpatched = {'version': '0.0.1', 'tasks': {
'fooChoiceWorld': {1: '1'}, 'barChoiceWorld': {2: '2'}}}
data = session_params._patch_file(unpatched)
self.assertIsInstance(data['tasks'], list)
# list(map(list, ...)) reduces each task dict to a list of its keys, so this
# checks both the one-key-per-item layout and that protocol order is kept
self.assertEqual([['fooChoiceWorld'], ['barChoiceWorld']], list(map(list, data['tasks'])))
# Check patching list of dicts with some containing more than 1 key
unpatched = {'tasks': [{'foo': {1: '1'}}, {'bar': {2: '2'}, 'baz': {3: '3'}}]}
data = session_params._patch_file(unpatched)
# The two-key dict ('bar', 'baz') should have been split into two items
self.assertEqual(3, len(data['tasks']))
self.assertEqual([['foo'], ['bar'], ['baz']], list(map(list, data['tasks'])))

def test_get_collections(self):
collections = session_params.get_collections(self.fixture)
Expand Down Expand Up @@ -561,10 +573,16 @@ def test_merge_params(self):
b = {'procedures': ['Imaging', 'Injection'], 'tasks': [{'fooChoiceWorld': {'collection': 'bar'}}]}
c = session_params.merge_params(a, b, copy=True)
self.assertCountEqual(['Imaging', 'Behavior training/tasks', 'Injection'], c['procedures'])
self.assertCountEqual(['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld'], (list(x)[0] for x in c['tasks']))
self.assertEqual(['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld'], [list(x)[0] for x in c['tasks']])
# Ensure a and b not modified
self.assertNotEqual(set(c['procedures']), set(a['procedures']))
self.assertNotEqual(set(a['procedures']), set(b['procedures']))
# Test duplicate tasks skipped while order kept constant
d = {'tasks': [a['tasks'][1], {'ephysChoiceWorld': {'collection': 'raw_task_data_02', 'sync_label': 'nidq'}}]}
e = session_params.merge_params(c, d, copy=True)
expected = ['passiveChoiceWorld', 'ephysChoiceWorld', 'fooChoiceWorld', 'ephysChoiceWorld']
self.assertEqual(expected, [list(x)[0] for x in e['tasks']])
self.assertDictEqual({'collection': 'raw_task_data_02', 'sync_label': 'nidq'}, e['tasks'][-1]['ephysChoiceWorld'])
# Test without copy
session_params.merge_params(a, b, copy=False)
self.assertCountEqual(['Imaging', 'Behavior training/tasks', 'Injection'], a['procedures'])
Expand Down
Loading