Update insertions cache with search_insertions (issue #153)
k1o0 committed Jan 9, 2025
1 parent 638a46e commit 21087e3
Showing 3 changed files with 243 additions and 12 deletions.
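
This commit lets OneAlyx.search_insertions write remote results into the local insertions (and sessions) cache tables, and adds a base One.search_insertions that filters those tables offline. A minimal usage sketch (the base URL is illustrative; a configured Alyx connection is assumed):

    from one.api import ONE

    one = ONE(base_url='https://alyx.example.org')  # illustrative Alyx instance
    # Remote query; the returned pages also populate one._cache['insertions'] and ['sessions']
    pids = one.search_insertions(subject='SWC_043', query_type='remote')
    # Subsequent local queries can filter the cached tables without hitting the database
    pids, recs = one.search_insertions(name='probe00', query_type='local', details=True)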
149 changes: 142 additions & 7 deletions one/api.py
@@ -415,7 +415,7 @@ def _download_dataset(self, dset, cache_dir=None, **kwargs) -> ALFPath:
"""
pass # pragma: no cover

def search(self, details=False, query_type=None, **kwargs):
def search(self, details=False, **kwargs):
"""
Searches sessions matching the given criteria and returns a list of matching eids
@@ -458,8 +458,6 @@ def search(self, details=False, query_type=None, **kwargs):
will be found).
details : bool
If true also returns a dict of dataset details.
query_type : str, None
Query cache ('local') or Alyx database ('remote').
Returns
-------
@@ -525,6 +523,7 @@ def sort_fcn(itm):

# Validate and get full name for queries
search_terms = self.search_terms(query_type='local')
kwargs.pop('query_type', None) # used by subclasses
queries = {util.autocomplete(k, search_terms): v for k, v in kwargs.items()}
for key, value in sorted(queries.items(), key=sort_fcn):
# No matches; short circuit
@@ -576,6 +575,106 @@ def sort_fcn(itm):
else:
return eids

def search_insertions(self, details=False, **kwargs):
"""
Searches insertions matching the given criteria and returns a list of matching probe IDs.
For a list of search terms, use the method
one.search_terms(query_type='remote', endpoint='insertions')
All of the search parameters, apart from dataset and dataset type, require a single value.
For dataset and dataset type, a single value or a list can be provided. Insertions
returned will contain all listed datasets.
Parameters
----------
session : str
A session eid, returns insertions associated with the session.
name : str
An insertion label, returns insertions with specified name.
lab : str
A lab name, returns insertions associated with the lab.
subject : str
A subject nickname, returns insertions associated with the subject.
task_protocol : str
A task protocol name (can be partial, i.e. any task protocol containing that str
will be found).
project(s) : str
The project name (can be partial, i.e. any project name containing that str
will be found).
dataset : str, list
One or more dataset names. Returns insertions containing all these datasets.
A dataset matches if it contains the search string e.g. 'wheel.position' matches
'_ibl_wheel.position.npy'.
dataset_qc_lte : int, str, one.alf.spec.QC
The maximum QC value for associated datasets.
details : bool
If true also returns a dict of dataset details.
Returns
-------
list
List of probe IDs (pids).
(list of dicts)
If details is True, also returns a list of dictionaries, each entry corresponding to a
matching insertion.
Notes
-----
- This method searches the local insertions cache table; if no insertions table has been loaded, a warning is issued and an empty result is returned.
Examples
--------
List the insertions associated with a given subject
>>> ins = one.search_insertions(subject='SWC_043')
"""
# Warn if no insertions table present
if (insertions := self._cache.get('insertions')) is None:
warnings.warn('No insertions data loaded.')
return ([], None) if details else []

# Validate and get full names
search_terms = ('model', 'name', 'json', 'serial', 'chronic_insertion')
search_terms += self._search_terms
kwargs.pop('query_type', None) # used by subclasses
arguments = {util.autocomplete(key, search_terms): value for key, value in kwargs.items()}
# Apply session filters first
session_kwargs = {k: v for k, v in arguments.items() if k in self._search_terms}
if session_kwargs:
eids = self.search(**session_kwargs, details=False, query_type='local')
insertions = insertions.loc[eids]
# Apply insertion filters
# Iterate over search filters, reducing the insertions table
for key, value in sorted(filter(lambda x: x[0] not in session_kwargs, arguments.items())):
if insertions.size == 0:
return ([], None) if details else []
# String fields
elif key in ('model', 'serial', 'name'):
query = '|'.join(ensure_list(value))
mask = insertions[key].str.contains(query, regex=self.wildcards)
insertions = insertions[mask.astype(bool, copy=False)]
else:
raise NotImplementedError(key)

# Return results
if insertions.size == 0:
return ([], None) if details else []
# Sort insertions
eids = insertions.index.get_level_values('eid').unique()
# NB: This will raise if no session in cache; may need to improve error handling here
sessions = self._cache['sessions'].loc[eids, ['date', 'subject', 'number']]
insertions = (insertions
.join(sessions, how='inner')
.sort_values(['date', 'subject', 'number', 'name'], ascending=False))
pids = insertions.index.get_level_values('id').to_list()

if details: # TODO replicate Alyx records here
return pids, insertions.reset_index(drop=True).to_dict('records', into=Bunch)
else:
return pids
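
For reference, a couple of local-mode calls against a loaded insertions table (subject and probe names are illustrative; session-level terms are delegated to One.search, insertion-level terms filter the table directly):

    pids = one.search_insertions(subject='flowers', name='probe00')        # session + insertion filters combined
    pids, recs = one.search_insertions(model=['3A', '3B2'], details=True)  # a list matches any of the given values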

def _check_filesystem(self, datasets, offline=None, update_exists=True, check_hash=True):
"""Update the local filesystem for the given datasets.
@@ -2136,9 +2235,9 @@ def search_insertions(self, details=False, query_type=None, **kwargs):
... ins = one.search_insertions(django='datasets__tags__name,' + tag)
"""
query_type = query_type or self.mode
if query_type == 'local' and 'insertions' not in self._cache.keys():
raise NotImplementedError('Searching on insertions required remote connection')
elif query_type == 'auto':
if query_type == 'local':
return super().search_insertions(details=details, query_type=query_type, **kwargs)
elif query_type == 'auto': # NB behaviour here may change in the future
_logger.debug('OneAlyx.search_insertions only supports remote queries')
query_type = 'remote'
# Get remote query params from REST endpoint
@@ -2165,12 +2264,48 @@ def search_insertions(self, details=False, query_type=None, **kwargs):
params.pop('django')

ins = self.alyx.rest('insertions', 'list', **params)
# Update cache table with results
if len(ins) == 0:
pass # no need to update cache here
elif isinstance(ins, list): # not a paginated response
self._update_insertions_table(ins)
else:
# populate first page
self._update_insertions_table(ins._cache[:ins.limit])
# Add callback for updating cache on future fetches
ins.add_callback(WeakMethod(self._update_insertions_table))

pids = util.LazyId(ins)
if not details:
return pids

return pids, ins
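
The callback pattern used in the paginated branch above, reduced to a standalone sketch (PaginatedList here is an illustrative stand-in for the one.webclient paginated response, not its real implementation):

    from weakref import WeakMethod

    class PaginatedList:
        """Fetches pages lazily and notifies registered callbacks with each new page."""
        def __init__(self, first_page, fetch_page):
            self._cache = list(first_page)  # records fetched so far
            self._fetch_page = fetch_page   # callable returning the next page of records
            self._callbacks = []

        def add_callback(self, cb):
            self._callbacks.append(cb)      # e.g. WeakMethod(one._update_insertions_table)

        def fetch_next(self):
            page = self._fetch_page()
            self._cache.extend(page)
            for cb in self._callbacks:
                fn = cb()  # a WeakMethod resolves to None once its object has been garbage collected
                if fn is not None:
                    fn(page)  # keep the caller's cache in sync with newly fetched records
            return page

In the method above the registered callback is the bound _update_insertions_table, so accessing a not-yet-fetched page also refreshes the insertions and sessions cache tables.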

def _update_insertions_table(self, insertions_records):
"""Update the insertions tables with a list of insertions records.
Parameters
----------
insertions_records : list of dict
A list of insertions records from the /insertions list endpoint.
Returns
-------
datetime.datetime
A timestamp of when the cache was updated.
"""
df = (pd.DataFrame(insertions_records)
.drop(['session_info'], axis=1)
.rename({'session': 'eid'}, axis=1)
.set_index(['eid', 'id'])
.sort_index())
if 'insertions' not in self._cache:
self._cache['insertions'] = df.iloc[0:0]
# Build sessions table
session_records = (x['session_info'] for x in insertions_records)
sessions_df = pd.DataFrame(next(zip(*map(ses2records, session_records))))
return self._update_cache_from_records(insertions=df, sessions=sessions_df)
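
The reshaping above, shown on a single illustrative record (field values are made up; a real record from the /insertions endpoint also carries datasets and other REST fields):

    import pandas as pd

    rec = {'id': 'b749446c-18e3-4987-820a-50649ab0f826',       # made-up insertion UUID
           'session': '01390fcc-4f86-4707-8a3b-4d9309feb0a1',  # parent session eid
           'name': 'probe00', 'model': '3B2', 'serial': '19051010090',
           'session_info': {'subject': 'flowers'}}             # full session record, used for the sessions table
    df = (pd.DataFrame([rec])
          .drop('session_info', axis=1)
          .rename({'session': 'eid'}, axis=1)
          .set_index(['eid', 'id'])
          .sort_index())
    # df now has a single row indexed on (eid, id) holding the name, model and serial columns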

def search(self, details=False, query_type=None, **kwargs):
"""
Searches sessions matching the given criteria and returns a list of matching eids.
@@ -2330,7 +2465,7 @@ def _update_sessions_table(self, session_records):
Returns
-------
datetime.datetime:
datetime.datetime
A timestamp of when the cache was updated.
"""
df = pd.DataFrame(next(zip(*map(ses2records, session_records))))
2 changes: 1 addition & 1 deletion one/converters.py
@@ -757,7 +757,7 @@ def ses2records(ses: dict):
Datasets frame.
"""
# Extract session record
eid = ses['url'][-36:]
eid = ses.get('id') or ses['url'][-36:] # id used for session_info field of probe insertion
session_keys = ('subject', 'start_time', 'lab', 'number', 'task_protocol', 'projects')
session_data = {k: v for k, v in ses.items() if k in session_keys}
session = (
104 changes: 100 additions & 4 deletions one/tests/test_one.py
@@ -90,19 +90,19 @@ def tearDown(self) -> None:
self.tempdir.cleanup()

def test_list_subjects(self):
"""Test One.list_subejcts"""
"""Test One.list_subejcts."""
subjects = self.one.list_subjects()
expected = ['KS005', 'ZFM-01935', 'ZM_1094', 'ZM_1150',
'ZM_1743', 'ZM_335', 'clns0730', 'flowers']
self.assertCountEqual(expected, subjects)

def test_offline_repr(self):
"""Test for One.offline property"""
"""Test for One.offline property."""
self.assertTrue('offline' in str(self.one))
self.assertTrue(str(self.tempdir.name) in str(self.one))

def test_one_search(self):
"""Test for One.search"""
"""Test for One.search."""
one = self.one
# Search subject
eids = one.search(subject='ZM_335')
@@ -212,6 +212,63 @@ def test_one_search(self):
self.assertEqual(len(eids), len(details))
self.assertCountEqual(details[0].keys(), self.one._cache.sessions.columns)

def test_search_insertions(self):
"""Test for One.search_insertions."""
one = self.one
# Create some records (eids taken from sessions cache fixture)
insertions = [
{'model': '3A', 'name': 'probe00', 'json': {}, 'serial': '19051010091',
'chronic_insertion': None, 'datasets': [uuid4(), uuid4()], 'id': str(uuid4()),
'eid': '01390fcc-4f86-4707-8a3b-4d9309feb0a1'},
{'model': 'Fiber', 'name': 'fiber00', 'json': {}, 'serial': '18000010000',
'chronic_insertion': str(uuid4()), 'datasets': [], 'id': str(uuid4()),
'eid': 'aaf101c3-2581-450a-8abd-ddb8f557a5ad'}
]
for i in range(2):
insertions.append({
'model': '3B2', 'name': f'probe{i:02}', 'json': {}, 'serial': f'19051010{i}90',
'chronic_insertion': None, 'datasets': [uuid4(), uuid4()], 'id': str(uuid4()),
'eid': '4e0b3320-47b7-416e-b842-c34dc9004cf8'
})
one._cache['insertions'] = pd.DataFrame(insertions).set_index(['eid', 'id']).sort_index()

# Search model
pids = one.search_insertions(model='3B2')
self.assertEqual(2, len(pids))
pids = one.search_insertions(model=['3B2', '3A'])
self.assertEqual(3, len(pids))

# Search name
pids = one.search_insertions(name='probe00')
self.assertEqual(2, len(pids))
pids = one.search_insertions(name='probe00', model='3B2')
self.assertEqual(1, len(pids))

# Search session details
pids = one.search_insertions(subject='flowers')
self.assertEqual(2, len(pids))
pids = one.search_insertions(subject='flowers', name='probe00')
self.assertEqual(1, len(pids))

# Unimplemented keys
self.assertRaises(NotImplementedError, one.search_insertions, json='foo')

# Details
pids, details = one.search_insertions(name='probe00', details=True)
self.assertEqual({'probe00'}, set(x['name'] for x in details))

# Check returned sorted by date, subject, number, and name
pids, details = one.search_insertions(details=True)
expected = sorted([d['date'] for d in details], reverse=True)
self.assertEqual(expected, [d['date'] for d in details])

# Empty returns
self.assertEqual([], one.search_insertions(model='3A', name='fiber00', serial='123'))
self.assertEqual([], one.search_insertions(model='foo'))
del one._cache['insertions']
with self.assertWarns(UserWarning):
self.assertEqual([], one.search_insertions())

def test_filter(self):
"""Test one.util.filter_datasets"""
datasets = self.one._cache.datasets.iloc[:5].copy()
@@ -1581,6 +1638,18 @@ def test_search(self):
eids = self.one.search(dataset_type='trials.table', date='2020-09-21', query_type='remote')
self.assertIn(self.eid, list(eids))

# Ensure that when calling with anything other than remote mode, One.search is used
with mock.patch('one.api.One.search') as offline_search, \
mock.patch.object(self.one.alyx, 'rest', return_value=[]) as alyx:
# In remote mode
self.one.search(subject='SWC_043', query_type='remote')
offline_search.assert_not_called(), alyx.assert_called()
alyx.reset_mock()
# In another mode
self.one.search(subject='SWC_043', query_type='auto')
offline_search.assert_called_with(details=False, query_type='auto', subject='SWC_043')
alyx.assert_not_called()

def test_search_insertions(self):
"""Test OneAlyx.search_insertion method in remote mode."""

@@ -1620,9 +1689,36 @@ def test_search_insertions(self):
self.assertEqual({lab}, {x['session_info']['lab'] for x in det})

# Test mode and field validation
self.assertRaises(NotImplementedError, self.one.search_insertions, query_type='local')
self.assertRaises(TypeError, self.one.search_insertions,
dataset=['wheel.times'], query_type='remote')
# Ensure that when calling with anything other than remote mode, One.search_insertions is used
with mock.patch('one.api.One.search_insertions') as offline_search, \
mock.patch.object(self.one.alyx, 'rest', return_value=[]) as alyx:
# In remote mode
self.one.search_insertions(subject='SWC_043', query_type='remote')
offline_search.assert_not_called(), alyx.assert_called()
alyx.reset_mock()
# In local mode
self.one.search_insertions(subject='SWC_043', query_type='local')
offline_search.assert_called_with(details=False, query_type='local', subject='SWC_043')
alyx.assert_not_called()

# Test limit arg, LazyId, and update with paginated response callback
self.one._reset_cache() # Remove insertions table
assert 'insertions' not in self.one._cache
pids = self.one.search_insertions(limit=2, query_type='remote')
self.assertEqual(2, len(self.one._cache.insertions),
'failed to update insertions cache with first page of search results')
self.assertEqual(2, len(self.one._cache.sessions),
'failed to update sessions cache with first page of search results')
self.assertIsInstance(pids, LazyId)
assert len(pids) > 5, 'in order to check paginated response callback we need several pages'
p = pids[-3] # access an uncached value
self.assertEqual(4, len(self.one._cache.insertions),
'failed to update insertions cache after page access')
self.assertEqual(4, len(self.one._cache.sessions),
'failed to update sessions cache after page access')
self.assertTrue(p in self.one._cache.insertions.index.get_level_values('id'))

def test_search_terms(self):
"""Test OneAlyx.search_terms."""
