diff --git a/one/api.py b/one/api.py index c97196a..1067a60 100644 --- a/one/api.py +++ b/one/api.py @@ -415,7 +415,7 @@ def _download_dataset(self, dset, cache_dir=None, **kwargs) -> ALFPath: """ pass # pragma: no cover - def search(self, details=False, query_type=None, **kwargs): + def search(self, details=False, **kwargs): """ Searches sessions matching the given criteria and returns a list of matching eids @@ -458,8 +458,6 @@ def search(self, details=False, query_type=None, **kwargs): will be found). details : bool If true also returns a dict of dataset details. - query_type : str, None - Query cache ('local') or Alyx database ('remote'). Returns ------- @@ -525,6 +523,7 @@ def sort_fcn(itm): # Validate and get full name for queries search_terms = self.search_terms(query_type='local') + kwargs.pop('query_type', None) # used by subclasses queries = {util.autocomplete(k, search_terms): v for k, v in kwargs.items()} for key, value in sorted(queries.items(), key=sort_fcn): # No matches; short circuit @@ -576,6 +575,106 @@ def sort_fcn(itm): else: return eids + def search_insertions(self, details=False, **kwargs): + """ + Searches insertions matching the given criteria and returns a list of matching probe IDs. + + For a list of search terms, use the method + + one.search_terms(query_type='remote', endpoint='insertions') + + All of the search parameters, apart from dataset and dataset type require a single value. + For dataset and dataset type, a single value or a list can be provided. Insertions + returned will contain all listed datasets. + + Parameters + ---------- + session : str + A session eid, returns insertions associated with the session. + name: str + An insertion label, returns insertions with specified name. + lab : str + A lab name, returns insertions associated with the lab. + subject : str + A subject nickname, returns insertions associated with the subject. + task_protocol : str + A task protocol name (can be partial, i.e. any task protocol containing that str + will be found). + project(s) : str + The project name (can be partial, i.e. any task protocol containing that str + will be found). + dataset : str, list + One or more dataset names. Returns sessions containing all these datasets. + A dataset matches if it contains the search string e.g. 'wheel.position' matches + '_ibl_wheel.position.npy'. + dataset_qc_lte : int, str, one.alf.spec.QC + The maximum QC value for associated datasets. + details : bool + If true also returns a dict of dataset details. + + Returns + ------- + list + List of probe IDs (pids). + (list of dicts) + If details is True, also returns a list of dictionaries, each entry corresponding to a + matching insertion. + + Notes + ----- + - This method does not use the local cache and therefore can not work in 'local' mode. + + Examples + -------- + List the insertions associated with a given subject + + >>> ins = one.search_insertions(subject='SWC_043') + """ + # Warn if no insertions table present + if (insertions := self._cache.get('insertions')) is None: + warnings.warn('No insertions data loaded.') + return ([], None) if details else [] + + # Validate and get full names + search_terms = ('model', 'name', 'json', 'serial', 'chronic_insertion') + search_terms += self._search_terms + kwargs.pop('query_type', None) # used by subclasses + arguments = {util.autocomplete(key, search_terms): value for key, value in kwargs.items()} + # Apply session filters first + session_kwargs = {k: v for k, v in arguments.items() if k in self._search_terms} + if session_kwargs: + eids = self.search(**session_kwargs, details=False, query_type='local') + insertions = insertions.loc[eids] + # Apply insertion filters + # Iterate over search filters, reducing the insertions table + for key, value in sorted(filter(lambda x: x[0] not in session_kwargs, kwargs.items())): + if insertions.size == 0: + return ([], None) if details else [] + # String fields + elif key in ('model', 'serial', 'name'): + query = '|'.join(ensure_list(value)) + mask = insertions[key].str.contains(query, regex=self.wildcards) + insertions = insertions[mask.astype(bool, copy=False)] + else: + raise NotImplementedError(key) + + # Return results + if insertions.size == 0: + return ([], None) if details else [] + # Sort insertions + eids = insertions.index.get_level_values('eid').unique() + # NB: This will raise if no session in cache; may need to improve error handling here + sessions = self._cache['sessions'].loc[eids, ['date', 'subject', 'number']] + insertions = (insertions + .join(sessions, how='inner') + .sort_values(['date', 'subject', 'number', 'name'], ascending=False)) + pids = insertions.index.get_level_values('id').to_list() + + if details: # TODO replicate Alyx records here + return pids, insertions.reset_index(drop=True).to_dict('records', into=Bunch) + else: + return pids + def _check_filesystem(self, datasets, offline=None, update_exists=True, check_hash=True): """Update the local filesystem for the given datasets. @@ -2136,9 +2235,9 @@ def search_insertions(self, details=False, query_type=None, **kwargs): ... ins = one.search_insertions(django='datasets__tags__name,' + tag) """ query_type = query_type or self.mode - if query_type == 'local' and 'insertions' not in self._cache.keys(): - raise NotImplementedError('Searching on insertions required remote connection') - elif query_type == 'auto': + if query_type == 'local': + return super().search_insertions(details=details, query_type=query_type, **kwargs) + elif query_type == 'auto': # NB behaviour here may change in the future _logger.debug('OneAlyx.search_insertions only supports remote queries') query_type = 'remote' # Get remote query params from REST endpoint @@ -2165,12 +2264,48 @@ def search_insertions(self, details=False, query_type=None, **kwargs): params.pop('django') ins = self.alyx.rest('insertions', 'list', **params) + # Update cache table with results + if len(ins) == 0: + pass # no need to update cache here + elif isinstance(ins, list): # not a paginated response + self._update_insetions_table(ins) + else: + # populate first page + self._update_insetions_table(ins._cache[:ins.limit]) + # Add callback for updating cache on future fetches + ins.add_callback(WeakMethod(self._update_insetions_table)) + pids = util.LazyId(ins) if not details: return pids return pids, ins + def _update_insetions_table(self, insertions_records): + """Update the insertions tables with a list of insertions records. + + Parameters + ---------- + insertions_records : list of dict + A list of insertions records from the /insertions list endpoint. + + Returns + ------- + datetime.datetime + A timestamp of when the cache was updated. + """ + df = (pd.DataFrame(insertions_records) + .drop(['session_info'], axis=1) + .rename({'session': 'eid'}, axis=1) + .set_index(['eid', 'id']) + .sort_index()) + if 'insertions' not in self._cache: + self._cache['insertions'] = df.iloc[0:0] + # Build sessions table + session_records = (x['session_info'] for x in insertions_records) + sessions_df = pd.DataFrame(next(zip(*map(ses2records, session_records)))) + return self._update_cache_from_records(insertions=df, sessions=sessions_df) + def search(self, details=False, query_type=None, **kwargs): """ Searches sessions matching the given criteria and returns a list of matching eids. @@ -2330,7 +2465,7 @@ def _update_sessions_table(self, session_records): Returns ------- - datetime.datetime: + datetime.datetime A timestamp of when the cache was updated. """ df = pd.DataFrame(next(zip(*map(ses2records, session_records)))) diff --git a/one/converters.py b/one/converters.py index 3c04080..3e3928f 100644 --- a/one/converters.py +++ b/one/converters.py @@ -757,7 +757,7 @@ def ses2records(ses: dict): Datasets frame. """ # Extract session record - eid = ses['url'][-36:] + eid = ses.get('id') or ses['url'][-36:] # id used for session_info field of probe insertion session_keys = ('subject', 'start_time', 'lab', 'number', 'task_protocol', 'projects') session_data = {k: v for k, v in ses.items() if k in session_keys} session = ( diff --git a/one/tests/test_one.py b/one/tests/test_one.py index c21862e..b2580b5 100644 --- a/one/tests/test_one.py +++ b/one/tests/test_one.py @@ -90,19 +90,19 @@ def tearDown(self) -> None: self.tempdir.cleanup() def test_list_subjects(self): - """Test One.list_subejcts""" + """Test One.list_subejcts.""" subjects = self.one.list_subjects() expected = ['KS005', 'ZFM-01935', 'ZM_1094', 'ZM_1150', 'ZM_1743', 'ZM_335', 'clns0730', 'flowers'] self.assertCountEqual(expected, subjects) def test_offline_repr(self): - """Test for One.offline property""" + """Test for One.offline property.""" self.assertTrue('offline' in str(self.one)) self.assertTrue(str(self.tempdir.name) in str(self.one)) def test_one_search(self): - """Test for One.search""" + """Test for One.search.""" one = self.one # Search subject eids = one.search(subject='ZM_335') @@ -212,6 +212,63 @@ def test_one_search(self): self.assertEqual(len(eids), len(details)) self.assertCountEqual(details[0].keys(), self.one._cache.sessions.columns) + def test_search_insertions(self): + """Test for One.search_insertions.""" + one = self.one + # Create some records (eids taken from sessions cache fixture) + insertions = [ + {'model': '3A', 'name': 'probe00', 'json': {}, 'serial': '19051010091', + 'chronic_insertion': None, 'datasets': [uuid4(), uuid4()], 'id': str(uuid4()), + 'eid': '01390fcc-4f86-4707-8a3b-4d9309feb0a1'}, + {'model': 'Fiber', 'name': 'fiber00', 'json': {}, 'serial': '18000010000', + 'chronic_insertion': str(uuid4()), 'datasets': [], 'id': str(uuid4()), + 'eid': 'aaf101c3-2581-450a-8abd-ddb8f557a5ad'} + ] + for i in range(2): + insertions.append({ + 'model': '3B2', 'name': f'probe{i:02}', 'json': {}, 'serial': f'19051010{i}90', + 'chronic_insertion': None, 'datasets': [uuid4(), uuid4()], 'id': str(uuid4()), + 'eid': '4e0b3320-47b7-416e-b842-c34dc9004cf8' + }) + one._cache['insertions'] = pd.DataFrame(insertions).set_index(['eid', 'id']).sort_index() + + # Search model + pids = one.search_insertions(model='3B2') + self.assertEqual(2, len(pids)) + pids = one.search_insertions(model=['3B2', '3A']) + self.assertEqual(3, len(pids)) + + # Search name + pids = one.search_insertions(name='probe00') + self.assertEqual(2, len(pids)) + pids = one.search_insertions(name='probe00', model='3B2') + self.assertEqual(1, len(pids)) + + # Search session details + pids = one.search_insertions(subject='flowers') + self.assertEqual(2, len(pids)) + pids = one.search_insertions(subject='flowers', name='probe00') + self.assertEqual(1, len(pids)) + + # Unimplemented keys + self.assertRaises(NotImplementedError, one.search_insertions, json='foo') + + # Details + pids, details = one.search_insertions(name='probe00', details=True) + self.assertEqual({'probe00'}, set(x['name'] for x in details)) + + # Check returned sorted by date, subject, number, and name + pids, details = one.search_insertions(details=True) + expected = sorted([d['date'] for d in details], reverse=True) + self.assertEqual(expected, [d['date'] for d in details]) + + # Empty returns + self.assertEqual([], one.search_insertions(model='3A', name='fiber00', serial='123')) + self.assertEqual([], one.search_insertions(model='foo')) + del one._cache['insertions'] + with self.assertWarns(UserWarning): + self.assertEqual([], one.search_insertions()) + def test_filter(self): """Test one.util.filter_datasets""" datasets = self.one._cache.datasets.iloc[:5].copy() @@ -1581,6 +1638,18 @@ def test_search(self): eids = self.one.search(dataset_type='trials.table', date='2020-09-21', query_type='remote') self.assertIn(self.eid, list(eids)) + # Ensure that when calling with anything other than remote mode, One.search is used + with mock.patch('one.api.One.search') as offline_search, \ + mock.patch.object(self.one.alyx, 'rest', return_value=[]) as alyx: + # In remote mode + self.one.search(subject='SWC_043', query_type='remote') + offline_search.assert_not_called(), alyx.assert_called() + alyx.reset_mock() + # In another mode + self.one.search(subject='SWC_043', query_type='auto') + offline_search.assert_called_with(details=False, query_type='auto', subject='SWC_043') + alyx.assert_not_called() + def test_search_insertions(self): """Test OneAlyx.search_insertion method in remote mode.""" @@ -1620,9 +1689,36 @@ def test_search_insertions(self): self.assertEqual({lab}, {x['session_info']['lab'] for x in det}) # Test mode and field validation - self.assertRaises(NotImplementedError, self.one.search_insertions, query_type='local') self.assertRaises(TypeError, self.one.search_insertions, dataset=['wheel.times'], query_type='remote') + # Ensure that when calling with anything other than remote mode, One is used + with mock.patch('one.api.One.search_insertions') as offline_search, \ + mock.patch.object(self.one.alyx, 'rest', return_value=[]) as alyx: + # In remote mode + self.one.search_insertions(subject='SWC_043', query_type='remote') + offline_search.assert_not_called(), alyx.assert_called() + alyx.reset_mock() + # In local mode + self.one.search_insertions(subject='SWC_043', query_type='local') + offline_search.assert_called_with(details=False, query_type='local', subject='SWC_043') + alyx.assert_not_called() + + # Test limit arg, LazyId, and update with paginated response callback + self.one._reset_cache() # Remove insertions table + assert 'insertions' not in self.one._cache + pids = self.one.search_insertions(limit=2, query_type='remote') + self.assertEqual(2, len(self.one._cache.insertions), + 'failed to update insertions cache with first page of search results') + self.assertEqual(2, len(self.one._cache.sessions), + 'failed to update sessions cache with first page of search results') + self.assertIsInstance(pids, LazyId) + assert len(pids) > 5, 'in order to check paginated response callback we need several pages' + p = pids[-3] # access an uncached value + self.assertEqual(4, len(self.one._cache.insertions), + 'failed to update insertions cache after page access') + self.assertEqual(4, len(self.one._cache.sessions), + 'failed to update insertions cache after page access') + self.assertTrue(p in self.one._cache.insertions.index.get_level_values('id')) def test_search_terms(self): """Test OneAlyx.search_terms."""