From 270fdd43db73a7e4cbece60599c80db07fc42e86 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 30 Aug 2023 10:50:29 +0300 Subject: [PATCH 01/68] Deprecate globus module --- ibllib/io/globus.py | 39 ++++++++++++++++++++++++++------ ibllib/io/raw_data_loaders.py | 13 +++++++---- ibllib/pipes/dynamic_pipeline.py | 6 ++--- ibllib/tests/test_io.py | 4 +++- 4 files changed, 45 insertions(+), 17 deletions(-) diff --git a/ibllib/io/globus.py b/ibllib/io/globus.py index 15bd8f9a1..492847866 100644 --- a/ibllib/io/globus.py +++ b/ibllib/io/globus.py @@ -1,17 +1,33 @@ -"""TODO: This entire module may be removed in favour of one.remote.globus""" +"""(DEPRECATED) Globus SDK utility functions. + +This has been deprecated in favour of the one.remote.globus module. +""" import re import sys import os from pathlib import Path +import warnings +import traceback +import logging import globus_sdk as globus from iblutil.io import params +for line in traceback.format_stack(): + print(line.strip()) + +msg = 'ibllib.io.globus has been deprecated. Use one.remote.globus instead. See stack above' +warnings.warn(msg, DeprecationWarning) +logging.getLogger(__name__).warning(msg) + + def as_globus_path(path): """ - Convert a path into one suitable for the Globus TransferClient. NB: If using tilda in path, - the home folder of your Globus Connect instance must be the same as the OS home dir. + (DEPRECATED) Convert a path into one suitable for the Globus TransferClient. + + NB: If using tilda in path, the home folder of your Globus Connect instance must be the same as + the OS home dir. :param path: A path str or Path instance :return: A formatted path string @@ -30,6 +46,9 @@ def as_globus_path(path): >>> '/E/FlatIron/integration' TODO Remove in favour of one.remote.globus.as_globus_path """ + msg = 'ibllib.io.globus.as_globus_path has been deprecated. Use one.remote.globus.as_globus_path instead.' + warnings.warn(msg, DeprecationWarning) + path = str(path) if ( re.match(r'/[A-Z]($|/)', path) @@ -64,7 +83,9 @@ def _login(globus_client_id, refresh_tokens=False): def login(globus_client_id): - # TODO Import from one.remove.globus + msg = 'ibllib.io.globus.login has been deprecated. Use one.remote.globus.Globus instead.' + warnings.warn(msg, DeprecationWarning) + token = _login(globus_client_id, refresh_tokens=False) authorizer = globus.AccessTokenAuthorizer(token['access_token']) tc = globus.TransferClient(authorizer=authorizer) @@ -72,7 +93,8 @@ def login(globus_client_id): def setup(globus_client_id, str_app='globus/default'): - # TODO Import from one.remove.globus + msg = 'ibllib.io.globus.setup has been deprecated. Use one.remote.globus.Globus instead.' + warnings.warn(msg, DeprecationWarning) # Lookup and manage consents there # https://auth.globus.org/v2/web/consents gtok = _login(globus_client_id, refresh_tokens=True) @@ -80,7 +102,8 @@ def setup(globus_client_id, str_app='globus/default'): def login_auto(globus_client_id, str_app='globus/default'): - # TODO Import from one.remove.globus + msg = 'ibllib.io.globus.login_auto has been deprecated. Use one.remote.globus.Globus instead.' + warnings.warn(msg, DeprecationWarning) token = params.read(str_app, {}) required_fields = {'refresh_token', 'access_token', 'expires_at_seconds'} if not (token and required_fields.issubset(token.as_dict())): @@ -92,7 +115,9 @@ def login_auto(globus_client_id, str_app='globus/default'): def get_local_endpoint(): - # TODO Remove in favour of one.remote.globus.get_local_endpoint_id + msg = 'ibllib.io.globus.get_local_endpoint has been deprecated. Use one.remote.globus.get_local_endpoint_id instead.' + warnings.warn(msg, DeprecationWarning) + if sys.platform == 'win32' or sys.platform == 'cygwin': id_path = Path(os.environ['LOCALAPPDATA']).joinpath("Globus Connect") else: diff --git a/ibllib/io/raw_data_loaders.py b/ibllib/io/raw_data_loaders.py index 200b8ca15..b1dad2e72 100644 --- a/ibllib/io/raw_data_loaders.py +++ b/ibllib/io/raw_data_loaders.py @@ -955,11 +955,14 @@ def patch_settings(session_path, collection='raw_behavior_data', ) if new_collection: - old_path = settings['SESSION_RAW_DATA_FOLDER'] - new_path = PureWindowsPath(settings['SESSION_RAW_DATA_FOLDER']).with_name(new_collection) - for k in settings.keys(): - if isinstance(settings[k], str): - settings[k] = settings[k].replace(old_path, str(new_path)) + if 'SESSION_RAW_DATA_FOLDER' not in settings: + _logger.warning('SESSION_RAW_DATA_FOLDER key not in settings; collection not updated') + else: + old_path = settings['SESSION_RAW_DATA_FOLDER'] + new_path = PureWindowsPath(settings['SESSION_RAW_DATA_FOLDER']).with_name(new_collection) + for k in settings.keys(): + if isinstance(settings[k], str): + settings[k] = settings[k].replace(old_path, str(new_path)) with open(file_path, 'w') as fp: json.dump(settings, fp, indent=' ') return settings diff --git a/ibllib/pipes/dynamic_pipeline.py b/ibllib/pipes/dynamic_pipeline.py index 044e242a6..f8063f9b0 100644 --- a/ibllib/pipes/dynamic_pipeline.py +++ b/ibllib/pipes/dynamic_pipeline.py @@ -90,7 +90,7 @@ def get_acquisition_description(protocol): else: devices = { 'cameras': { - 'left': {'collection': 'raw_video_data', 'sync_label': 'frame2ttl'}, + 'left': {'collection': 'raw_video_data', 'sync_label': 'audio'}, }, 'microphone': { 'microphone': {'collection': 'raw_behavior_data', 'sync_label': None} @@ -98,9 +98,7 @@ def get_acquisition_description(protocol): } acquisition_description = { # this is the current ephys pipeline description 'devices': devices, - 'sync': { - 'bpod': {'collection': 'raw_behavior_data', 'extension': 'bin'} - }, + 'sync': {'bpod': {'collection': 'raw_behavior_data'}}, 'procedures': ['Behavior training/tasks'], 'projects': ['ibl_neuropixel_brainwide_01'] } diff --git a/ibllib/tests/test_io.py b/ibllib/tests/test_io.py index 967a88d32..4c88b6c74 100644 --- a/ibllib/tests/test_io.py +++ b/ibllib/tests/test_io.py @@ -7,9 +7,9 @@ import sys import logging import json +from datetime import datetime import numpy as np -import numpy.testing from one.api import ONE from iblutil.io import params import yaml @@ -363,6 +363,7 @@ def setUp(self): self.addCleanup(self.patcher.stop) def test_as_globus_path(self): + assert datetime.now() < datetime(2023, 10, 30) # A Windows path if sys.platform == 'win32': # "/E/FlatIron/integration" @@ -380,6 +381,7 @@ def test_as_globus_path(self): @unittest.mock.patch('iblutil.io.params.read') def test_login_auto(self, mock_params): + assert datetime.now() < datetime(2023, 10, 30) client_id = 'h3u2ier' # Test ValueError thrown with incorrect parameters mock_params.return_value = None # No parameters saved From 0e608937b70ead6f5ff00907a3c0d7971bd85b7d Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 6 Sep 2023 15:27:34 +0300 Subject: [PATCH 02/68] Use ONE Globus in patcher --- ibllib/oneibl/data_handlers.py | 7 ++--- ibllib/oneibl/patcher.py | 47 ++++++++++++++++------------------ ibllib/qc/task_metrics.py | 2 +- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/ibllib/oneibl/data_handlers.py b/ibllib/oneibl/data_handlers.py index de39f31b8..1b571496d 100644 --- a/ibllib/oneibl/data_handlers.py +++ b/ibllib/oneibl/data_handlers.py @@ -121,7 +121,7 @@ def __init__(self, session_path, signatures, one=None): """ from one.remote.globus import Globus, get_lab_from_endpoint_id # noqa super().__init__(session_path, signatures, one=one) - self.globus = Globus(client_name='server') + self.globus = Globus(client_name='server', headless=True) # on local servers set up the local root path manually as some have different globus config paths self.globus.endpoints['local']['root_path'] = '/mnt/s0/Data/Subjects' @@ -131,7 +131,8 @@ def __init__(self, session_path, signatures, one=None): # For cortex lab we need to get the endpoint from the ibl alyx if self.lab == 'cortexlab': - self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=ONE(base_url='https://alyx.internationalbrainlab.org').alyx) + alyx = AlyxClient(base_url='https://alyx.internationalbrainlab.org') + self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=alyx) else: self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=self.one.alyx) @@ -255,7 +256,7 @@ def uploadData(self, outputs, version, **kwargs): """ # Set up Globus from one.remote.globus import Globus # noqa - self.globus = Globus(client_name='server') + self.globus = Globus(client_name='server', headless=True) self.lab = session_path_parts(self.session_path, as_dict=True)['lab'] if self.lab == 'cortexlab' and 'cortexlab' in self.one.alyx.base_url: base_url = 'https://alyx.internationalbrainlab.org' diff --git a/ibllib/oneibl/patcher.py b/ibllib/oneibl/patcher.py index da97285d6..4e5ec2921 100644 --- a/ibllib/oneibl/patcher.py +++ b/ibllib/oneibl/patcher.py @@ -18,7 +18,6 @@ _logger = logging.getLogger(__name__) -FLAT_IRON_GLOBUS_ID = 'ab2d064c-413d-11eb-b188-0ee0d5d9299f' FLATIRON_HOST = 'ibl.flatironinstitute.org' FLATIRON_PORT = 61022 FLATIRON_USER = 'datauser' @@ -199,19 +198,19 @@ class GlobusPatcher(Patcher): def __init__(self, client_name='default', one=None, label='ibllib patch'): assert one - self.local_endpoint = getattr(globus.load_client_params(f'globus.{client_name}'), - 'local_endpoint', globus.get_local_endpoint_id()) - self.transfer_client = globus.create_globus_client(client_name) + self.globus = globus.Globus(client_name) self.label = label + # get a dictionary of data repositories from Alyx (with globus ids) + self.globus.fetch_endpoints_from_alyx(one.alyx) + flatiron_id = self.globus.endpoints['flatiron_cortexlab']['id'] + if not 'flatiron' in self.globus.endpoints: + self.globus.add_endpoint(flatiron_id, 'flatiron', root_path='/') + self.globus.endpoints['flatiron'] = self.globus.endpoints['flatiron_cortexlab'] # transfers/delete from the current computer to the flatiron: mandatory and executed first + local_id = self.globus.endpoints['local']['id'] self.globus_transfer = globus_sdk.TransferData( - self.transfer_client, self.local_endpoint, FLAT_IRON_GLOBUS_ID, verify_checksum=True, - sync_level='checksum', label=label) - self.globus_delete = globus_sdk.DeleteData( - self.transfer_client, FLAT_IRON_GLOBUS_ID, verify_checksum=True, - sync_level='checksum', label=label) - # get a dictionary of data repositories from Alyx (with globus ids) - self.repos = {r['name']: r for r in one.alyx.rest('data-repository', 'list')} + self.globus.client, local_id, flatiron_id, verify_checksum=True, sync_level='checksum', label=label) + self.globus_delete = globus_sdk.DeleteData(self.globus.client, flatiron_id, label=label) # transfers/delete from flatiron to optional third parties to synchronize / delete self.globus_transfers_locals = {} self.globus_deletes_locals = {} @@ -232,7 +231,7 @@ def _scp(self, local_path, remote_path, dry=True): def _rm(self, flatiron_path, dry=True): flatiron_path = Path('/').joinpath(flatiron_path.relative_to(Path(FLATIRON_MOUNT))) - _logger.info(f"Globus del {flatiron_path}") + _logger.info(f'Globus del {flatiron_path}') if not dry: if isinstance(self.globus_delete, globus_sdk.transfer.data.DeleteData): self.globus_delete.add_item(flatiron_path) @@ -253,25 +252,24 @@ def patch_datasets(self, file_list, **kwargs): for dset in responses: # get the flatiron path fr = next(fr for fr in dset['file_records'] if 'flatiron' in fr['data_repository']) - flatiron_path = self.repos[fr['data_repository']]['globus_path'] - flatiron_path = Path(flatiron_path).joinpath(fr['relative_path']) - flatiron_path = add_uuid_string(flatiron_path, dset['id']).as_posix() + relative_path = add_uuid_string(fr['relative_path'], dset['id']).as_posix() + flatiron_path = self.globus.to_address(relative_path, fr['data_repository']) # loop over the remaining repositories (local servers) and create a transfer # from flatiron to the local server for fr in dset['file_records']: if fr['data_repository'] == DMZ_REPOSITORY: continue - repo_gid = self.repos[fr['data_repository']]['globus_endpoint_id'] - if repo_gid == FLAT_IRON_GLOBUS_ID: + repo_gid = self.globus.endpoints[fr['data_repository']]['id'] + flatiron_id = self.globus.endpoints['flatiron']['id'] + if repo_gid == flatiron_id: continue # if there is no transfer already created, initialize it if repo_gid not in self.globus_transfers_locals: self.globus_transfers_locals[repo_gid] = globus_sdk.TransferData( - self.transfer_client, FLAT_IRON_GLOBUS_ID, repo_gid, verify_checksum=True, + self.globus.client, flatiron_id, repo_gid, verify_checksum=True, sync_level='checksum', label=f"{self.label} on {fr['data_repository']}") # get the local server path and create the transfer item - local_server_path = self.repos[fr['data_repository']]['globus_path'] - local_server_path = Path(local_server_path).joinpath(fr['relative_path']) + local_server_path = self.globus.to_address(fr['relative_path'], fr['data_repository']) self.globus_transfers_locals[repo_gid].add_item(flatiron_path, local_server_path) return responses @@ -282,7 +280,7 @@ def launch_transfers(self, local_servers=False): :param: local_servers (False): if True, sync the local servers after the main transfer :return: None """ - gtc = self.transfer_client + gtc = self.globus.client def _wait_for_task(resp): # patcher.transfer_client.get_task(task_id='364fbdd2-4deb-11eb-8ffb-0a34088e79f9') @@ -320,8 +318,7 @@ def _wait_for_task(resp): self.globus_delete = globus_sdk.DeleteData( gtc, endpoint=self.globus_delete['endpoint'], - label=self.globus_delete['label'], - verify_checksum=True, sync_level='checksum') + label=self.globus_delete['label']) # launch the local transfers and local deletes if local_servers: @@ -337,11 +334,11 @@ def launch_transfers_secondary(self): for lt in self.globus_transfers_locals: transfer = self.globus_transfers_locals[lt] if len(transfer['DATA']) > 0: - self.transfer_client.submit_transfer(transfer) + self.globus.client.submit_transfer(transfer) for ld in self.globus_deletes_locals: delete = self.globus_deletes_locals[ld] if len(transfer['DATA']) > 0: - self.transfer_client.submit_delete(delete) + self.globus.client.submit_delete(delete) class SSHPatcher(Patcher): diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 36f2b4806..72cdd3ca6 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -151,7 +151,7 @@ def compute(self, **kwargs): if self.extractor is None: kwargs['download_data'] = kwargs.pop('download_data', self.download_data) self.load_data(**kwargs) - self.log.info(f"Session {self.session_path}: Running QC on behavior data...") + self.log.info(f'Session {self.session_path}: Running QC on behavior data...') self.metrics, self.passed = get_bpodqc_metrics_frame( self.extractor.data, wheel_gain=self.extractor.settings['STIM_GAIN'], # The wheel gain From d97ccd1751163f133d6801b02b730c0353088680 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 6 Sep 2023 15:48:00 +0300 Subject: [PATCH 03/68] Add delete_dataset method --- ibllib/oneibl/patcher.py | 72 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/ibllib/oneibl/patcher.py b/ibllib/oneibl/patcher.py index 4e5ec2921..d7c20d95a 100644 --- a/ibllib/oneibl/patcher.py +++ b/ibllib/oneibl/patcher.py @@ -1,6 +1,8 @@ import abc import ftplib from pathlib import Path, PurePosixPath, WindowsPath +from collections import defaultdict +from itertools import groupby, starmap import subprocess import logging from getpass import getpass @@ -9,8 +11,9 @@ import globus_sdk import iblutil.io.params as iopar from one.alf.files import get_session_path, add_uuid_string -from one.alf.spec import is_uuid_string +from one.alf.spec import is_uuid_string, is_uuid from one import params +from one.webclient import AlyxClient from one.converters import path_from_dataset from one.remote import globus @@ -203,7 +206,7 @@ def __init__(self, client_name='default', one=None, label='ibllib patch'): # get a dictionary of data repositories from Alyx (with globus ids) self.globus.fetch_endpoints_from_alyx(one.alyx) flatiron_id = self.globus.endpoints['flatiron_cortexlab']['id'] - if not 'flatiron' in self.globus.endpoints: + if 'flatiron' not in self.globus.endpoints: self.globus.add_endpoint(flatiron_id, 'flatiron', root_path='/') self.globus.endpoints['flatiron'] = self.globus.endpoints['flatiron_cortexlab'] # transfers/delete from the current computer to the flatiron: mandatory and executed first @@ -341,6 +344,71 @@ def launch_transfers_secondary(self): self.globus.client.submit_delete(delete) +class IBLGlobusPatcher(Patcher, globus.Globus): + """This is a replacement for the GlobusPatcher class, utilizing the ONE Globus class. + + The GlobusPatcher class is more complicated but has the advantage of being able to launch + transfers independently to registration, although it remains to be seen whether this is useful. + """ + def __init__(self, alyx=None, client_name='default'): + """ + + Parameters + ---------- + alyx : one.webclient.AlyxClient + An instance of Alyx to use. + client_name : str, default='default' + The Globus client name. + """ + self.alyx = alyx or AlyxClient() + globus.Globus.__init__(client_name=client_name) # NB we don't init Patcher as we're not using ONE + + def delete_dataset(self, dataset, dry=False): + """ + Delete a dataset off Alyx and remove file record from all Globus repositories. + + Parameters + ---------- + dataset : uuid.UUID, str, dict + The dataset record or ID to delete. + dry : bool + If true, dataset is not deleted and file paths that would be removed are returned. + + Returns + ------- + list of uuid.UUID + A list of Globus delete task IDs if dry is false. + dict of str + A map of data repository names and relative paths of the deleted files. + """ + if is_uuid(dataset): + did = dataset + dataset = self.alyx.rest('datasets', 'read', id=did) + else: + did = dataset['url'].split('/')[-1] + + files_by_repo = defaultdict(list) # uuid.UUID -> [pathlib.PurePosixPath] + file_records = filter(lambda x: x['exists'], dataset['file_records']) + for repo, record in groupby(file_records, lambda x: x['data_repository']): + if not record['globus_id']: + raise NotImplementedError + if repo not in self.endpoints: + self.add_endpoint(repo, alyx=self.alyx) + filepath = PurePosixPath(record['relative_path']) + if 'flatiron' in repo: + filepath = add_uuid_string(filepath, did) + files_by_repo[repo].append(filepath) + + if dry: + return [], files_by_repo + + # Delete the files + task_ids = list(starmap(self.delete_data, files_by_repo.items())) + # Delete the dataset from Alyx + self.alyx.rest('datasets', 'delete', id=did) + return task_ids, files_by_repo + + class SSHPatcher(Patcher): """ Requires SSH keys access on the FlatIron From 64e72f34ead65faa010ce658579f66b0e8d0b69c Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 11 Sep 2023 13:23:48 +0300 Subject: [PATCH 04/68] Delete S3 files --- ibllib/oneibl/patcher.py | 67 ++++++++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 13 deletions(-) diff --git a/ibllib/oneibl/patcher.py b/ibllib/oneibl/patcher.py index d7c20d95a..e8734ed72 100644 --- a/ibllib/oneibl/patcher.py +++ b/ibllib/oneibl/patcher.py @@ -2,7 +2,9 @@ import ftplib from pathlib import Path, PurePosixPath, WindowsPath from collections import defaultdict -from itertools import groupby, starmap +from itertools import starmap +from subprocess import Popen, PIPE, STDOUT +from urllib.parse import urlparse import subprocess import logging from getpass import getpass @@ -32,6 +34,13 @@ SDSC_PATCH_PATH = PurePosixPath('/home/datauser/temp') +def url2uri(data_path): + parsed = urlparse(data_path) + assert parsed.netloc and parsed.scheme and parsed.path + bucket_name = parsed.netloc.split('.')[0] + return f's3://{bucket_name}{parsed.path}' + + def _run_command(cmd, dry=True): _logger.info(cmd) if dry: @@ -387,23 +396,55 @@ def delete_dataset(self, dataset, dry=False): else: did = dataset['url'].split('/')[-1] - files_by_repo = defaultdict(list) # uuid.UUID -> [pathlib.PurePosixPath] + def is_aws(repository_name): + return repository_name.startswith('aws_') + + files_by_repo = defaultdict(list) # str -> [pathlib.PurePosixPath] + s3_files = [] file_records = filter(lambda x: x['exists'], dataset['file_records']) - for repo, record in groupby(file_records, lambda x: x['data_repository']): - if not record['globus_id']: - raise NotImplementedError - if repo not in self.endpoints: - self.add_endpoint(repo, alyx=self.alyx) - filepath = PurePosixPath(record['relative_path']) - if 'flatiron' in repo: - filepath = add_uuid_string(filepath, did) - files_by_repo[repo].append(filepath) + for record in file_records: + repo = self.repo_from_alyx(record['data_repository'], self.alyx) + # Handle S3 files + if not repo['globus_endpoint_id'] or repo['repository_type'] != 'Fileserver': + if is_aws(repo['name']): + s3_files.append(url2uri(record['data_url'])) + files_by_repo[repo['name']].append(PurePosixPath(record['relative_path'])) + else: + _logger.error('Unable to delete from %s', repo['name']) + else: + # Handle Globus files + if repo['name'] not in self.endpoints: + self.add_endpoint(repo['name'], alyx=self.alyx) + filepath = PurePosixPath(record['relative_path']) + if 'flatiron' in repo['name']: + filepath = add_uuid_string(filepath, did) + files_by_repo[repo['name']].append(filepath) + + # Remove S3 files + if s3_files: + cmd = ['aws', 's3', 'rm', *s3_files, '--profile', 'ibladmin'] + if dry: + cmd.append('--dryrun') + if _logger.level > logging.DEBUG: + log_function = _logger.error + cmd.append('--only-show-errors') # Suppress verbose output + else: + log_function = _logger.debug + cmd.append('--no-progress') # Suppress progress info, estimated time, etc. + _logger.debug(' '.join(cmd)) + process = Popen(cmd, stdout=PIPE, stderr=STDOUT) + with process.stdout: + for line in iter(process.stdout.readline, b''): + log_function(line.decode().strip()) + assert process.wait() == 0 if dry: return [], files_by_repo - # Delete the files - task_ids = list(starmap(self.delete_data, files_by_repo.items())) + # Remove Globus files + globus_files_map = filter(lambda x: not is_aws(x[0]), files_by_repo.items()) + task_ids = list(starmap(self.delete_data, map(reversed, globus_files_map))) + # Delete the dataset from Alyx self.alyx.rest('datasets', 'delete', id=did) return task_ids, files_by_repo From 1889e7cfe73d4a5996748c596b7425500a98dde8 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 15 Sep 2023 12:22:21 +0300 Subject: [PATCH 05/68] Bump ONE version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 544601f8a..4e6050008 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ tqdm>=4.32.1 ibl-neuropixel>=0.4.0 iblutil>=1.7.0 labcams # widefield extractor -ONE-api>=2.2 +ONE-api>=2.3 slidingRP>=1.0.0 # steinmetz lab refractory period metrics wfield==0.3.7 # widefield extractor frozen for now (2023/07/15) until Joao fixes latest version psychofit From 7d811e0f26da3b5835091dcc69a43e63a9f33613 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Mon, 25 Sep 2023 08:50:50 +0100 Subject: [PATCH 06/68] full session wheel alignment --- ibllib/io/extractors/video_motion.py | 512 ++++++++++++++++++++++++++- 1 file changed, 511 insertions(+), 1 deletion(-) diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index ef75187b5..cde4b393c 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -4,20 +4,29 @@ """ import matplotlib import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec from matplotlib.widgets import RectangleSelector import numpy as np -from scipy import signal +from scipy import signal, ndimage, interpolate import cv2 from itertools import cycle import matplotlib.animation as animation import logging from pathlib import Path +from joblib import Parallel, delayed, cpu_count +from neurodsp.utils import WindowGenerator from one.api import ONE import ibllib.io.video as vidio from iblutil.util import Bunch +from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map +import ibllib.io.raw_data_loaders as raw +from ibllib.io.extractors.camera import CameraTimestampsBpod import brainbox.video as video import brainbox.behavior.wheel as wh +from brainbox.singlecell import bin_spikes +from brainbox.behavior.dlc import likelihood_threshold, get_speed +from brainbox.task.trials import find_trial_ids import one.alf.io as alfio from one.alf.spec import is_session_path, is_uuid_string @@ -383,3 +392,504 @@ def process_key(event): anim.save(str(filename), writer=writer) else: plt.show() + + +session_path = one.eid2path(eid) +session_path = Path('/mnt/ibl').joinpath(*session_path.parts[-5:]) +class MotionAlignmentFullSession: + def __init__(self, session_path, label, **kwargs): + self.session_path = session_path + self.label = label + self.threshold = kwargs.get('threshold', 20) + self.behavior = kwargs.get('behavior', False) + self.twin = kwargs.get('twin', 150) + self.nprocess = kwargs.get('nprocess', int(cpu_count() - cpu_count() / 4)) + + self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None), behavior=self.behavior) + self.roi, self.mask = self.get_roi_mask() + + def load_data(self, sync='nidq', location=None, behavior=False): + def fix_keys(alf_object): + ob = Bunch() + for key in alf_object.keys(): + vals = alf_object[key] + ob[key.split('.')[0]] = vals + return ob + + alf_path = self.session_path.joinpath('alf') + wheel = (fix_keys(alfio.load_object(alf_path, 'wheel')) if location == 'SDSC' + else alfio.load_object(alf_path, 'wheel')) + self.wheel_timestamps = wheel.timestamps + wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000) + self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000) + self.camera_times = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.times.*.npy'))) + self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob( + f'_iblrig_{self.label}Camera.raw.*.mp4'))) + self.camera_meta = vidio.get_video_meta(self.camera_path) + + # TODO should read in the description file to get the correct sync location + if sync == 'nidq': + sync, chmap = get_sync_and_chn_map(self.session_path, sync_collection='raw_ephys_data') + sr = get_sync_fronts(sync, chmap[f'{self.label}_camera']) + self.ttls = sr.times[::2] + else: + cam_extractor = CameraTimestampsBpod(session_path=self.session_path) + cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data') + self.ttls = cam_extractor._times_from_bpod() + + self.tdiff = self.ttls.size - self.camera_meta['length'] + + if self.tdiff < 0: + self.ttl_times = self.ttls + self.times = np.r_[self.ttl_times, np.full((np.abs(self.tdiff)), np.nan)] + self.short_flag = True + elif self.tdiff > 0: + self.ttl_times = self.ttls[self.tdiff:] + self.times = self.ttls[self.tdiff:] + self.short_flag = False + + if behavior: + self.trials = alfio.load_file_content(next(alf_path.glob(f'_ibl_trials.table.*.pqt'))) + self.dlc = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.dlc.*.pqt'))) + self.dlc = likelihood_threshold(self.dlc) + + self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0]) + + def get_roi_mask(self): + + if self.label == 'right': + roi = ((450, 512), (120, 200)) + else: + roi = ((900, 1024), (850, 1010)) + roi_mask = (*[slice(*r) for r in roi], 0) + + return roi, roi_mask + + def find_contaminated_frames(self, video_frames, thresold=20, normalise=True): + high = np.zeros((video_frames.shape[0])) + for idx, frame in enumerate(video_frames): + ret, _ = cv2.threshold(cv2.GaussianBlur(frame, (5, 5), 0), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + high[idx] = ret + + if normalise: + high -= np.min(high) + + contaminated_frames = np.where(high > thresold)[0] + + return contaminated_frames + + def compute_motion_energy(self, first, last, wg, iw): + + if iw == wg.nwin - 1: + return + + cap = cv2.VideoCapture(self.camera_path) + frames = vidio.get_video_frames_preload(cap, np.arange(first, last), mask=self.mask) + idx = self.find_contaminated_frames(frames, self.threshold) + + if len(idx) != 0: + + before_status = False + after_status = False + + counter = 0 + n_frames = 200 + while np.any(idx == 0) and counter < 20 and iw != 0: + n_before_offset = (counter + 1) * n_frames + first -= n_frames + extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(first - n_frames, first), + mask=self.mask) + frames = np.concatenate([extra_frames, frames], axis=0) + + idx = self.find_contaminated_frames(frames, self.threshold) + before_status = True + counter += 1 + if counter > 0: + print(f'In before: {counter}') + + counter = 0 + while np.any(idx == frames.shape[0] - 1) and counter < 20 and iw != wg.nwin - 1: + n_after_offset = (counter + 1) * n_frames + last += n_frames + extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last, last + n_frames), + mask=self.mask) + frames = np.concatenate([frames, extra_frames], axis=0) + idx = self.find_contaminated_frames(frames, self.threshold) + after_status = True + counter += 1 + + if counter > 0: + print(f'In after: {counter}') + + intervals = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1) + for ints in intervals: + if len(ints) > 0 and ints[0] == 0: + ints = ints[1:] + if len(ints) > 0 and ints[-1] == frames.shape[0] - 1: + ints = ints[:-1] + th_all = np.zeros_like(frames[0]) + for idx in ints: + img = np.copy(frames[idx]) + blur = cv2.GaussianBlur(img, (5, 5), 0) + ret, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + th = cv2.GaussianBlur(th, (5, 5), 10) + th_all += th + vals = np.mean(np.dstack([frames[ints[0] - 1], frames[ints[-1] + 1]]), axis=-1) + for idx in ints: + img = frames[idx] + img[th_all > 0] = vals[th_all > 0] + + if before_status: + frames = frames[n_before_offset:] + if after_status: + frames = frames[:(-1 * n_after_offset)] + + frame_me, _ = video.motion_energy(frames, diff=2, normalize=False) + + cap.release() + + return frame_me[2:] + + def compute_shifts(self, times, me, first, last, iw, wg): + + if iw == wg.nwin - 1: + return np.nan, np.nan + t_first = times[first] + t_last = times[last] + if np.isnan(t_last) and np.isnan(t_first): + return np.nan, np.nan + elif np.isnan(t_last): + t_last = times[np.where(~np.isnan(times))[0][-1]] + + mask = np.logical_and(times >= t_first, times <= t_last) + align_me = me[np.where(mask)[0]] + align_me = (align_me - np.nanmin(align_me)) / (np.nanmax(align_me) - np.nanmin(align_me)) + + # Find closest timepoints in wheel that match the camera times + wh_mask = np.logical_and(self.wheel_time >= t_first, self.wheel_time <= t_last) + if np.sum(wh_mask) == 0: + return np.nan, np.nan + xs = np.searchsorted(self.wheel_time[wh_mask], times[mask]) + xs[xs == np.sum(wh_mask)] = np.sum(wh_mask) - 1 + # Convert to normalized speed + vs = np.abs(self.wheel_vel[wh_mask][xs]) + vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) + + isnan = np.isnan(align_me) + + if np.sum(isnan) > 0: + where_nan = np.where(isnan)[0] + assert where_nan[0] == 0 + assert where_nan[-1] == np.sum(isnan) - 1 + + if np.all(isnan): + return np.nan, np.nan + + xcorr = signal.correlate(align_me[~isnan], vs[~isnan]) + shift = np.nanargmax(xcorr) - align_me[~isnan].size + 2 + + return shift, t_first + (t_last - t_first) / 2 + + def clean_shifts(self, x, n=1): + y = x.copy() + dy = np.diff(y, prepend=y[0]) + while True: + pos = np.where(dy == 1)[0] if n == 1 else np.where(dy > 2)[0] + # added frames: this doesn't make sense and this is noise + if pos.size == 0: + break + neg = np.where(dy == -1)[0] if n == 1 else np.where(dy < -2)[0] + + if len(pos) > len(neg): + neg = np.append(neg, dy.size - 1) + + iss = np.minimum(np.searchsorted(neg, pos), neg.size - 1) + imin = np.argmin(np.minimum(np.abs(pos - neg[iss - 1]), np.abs(pos - neg[iss]))) + + idx = np.max([0, iss[imin] - 1]) + ineg = neg[idx:iss[imin] + 1] + ineg = ineg[np.argmin(np.abs(pos[imin] - ineg))] + dy[pos[imin]] = 0 + dy[ineg] = 0 + + return np.cumsum(dy) + y[0] + + + def qc_shifts(self, shifts, shifts_filt): + + ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100 if self.tdiff < 0 else 0 + nan_per = (np.sum(np.isnan(shifts_filt)) / shifts_filt.size) * 100 + shifts_sum = np.where(np.abs(np.diff(shifts)) > 10)[0].size + shifts_filt_sum = np.where(np.abs(np.diff(shifts_filt)) > 1)[0].size + + qc = {} + qc['ttl_per'] = ttl_per + qc['nan_per'] = nan_per + qc['shifts_sum'] = shifts_sum + qc['shifts_filt_sum'] = shifts_filt_sum + + return qc + + # # If more than 10% of ttls are missing we don't get new times + # if ttl_per > 10: + # return False + # # If too many of the shifts are nans it means the alignment is not accurate + # if nan_per > 40: + # return False + # # If there are too many artefacts could be errors + # if shifts_sum > 80: + # return False + # # If there are jumps > 1 in the filtered shifts then there is a problem + # if shifts_filt_sum > 0: + # return False + # + # return True + + def extract_times(self, shifts_filt, t_shifts): + + fps = 1 / np.nanmean(np.diff(self.ttl_times)) + t_new = t_shifts - (shifts_filt * 1 / fps) + fcn = interpolate.interp1d(t_shifts, t_new, fill_value="extrapolate") + new_times = fcn(self.ttl_times) + + # TODO if short we need to interpolate the end times + + return new_times + @staticmethod + def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True, + norm=False, + axs=None): + pre_time = 0.4 + post_time = 1 + raster_bin = 0.01 + psth_bin = 0.05 + raster, t_raster = bin_spikes( + spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=raster_bin, weights=weights) + psth, t_psth = bin_spikes( + spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=psth_bin, weights=weights) + + if fr: + psth = psth / psth_bin + + if norm: + psth = psth - np.repeat(psth[:, 0][:, np.newaxis], psth.shape[1], axis=1) + raster = raster - np.repeat(raster[:, 0][:, np.newaxis], raster.shape[1], axis=1) + + dividers = [0] + dividers + [len(trial_idx)] + if axs is None: + fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, + sharex=True) + else: + fig = axs[0].get_figure() + + label, lidx = np.unique(labels, return_index=True) + label_pos = [] + for lab, lid in zip(label, lidx): + idx = np.where(np.array(labels) == lab)[0] + for iD in range(len(idx)): + if iD == 0: + t_ids = trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1] + t_ints = dividers[idx[iD] + 1] - dividers[idx[iD]] + else: + t_ids = np.r_[t_ids, trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]] + t_ints = np.r_[t_ints, dividers[idx[iD] + 1] - dividers[idx[iD]]] + + psth_div = np.nanmean(psth[t_ids], axis=0) + std_div = np.nanstd(psth[t_ids], axis=0) / np.sqrt(len(t_ids)) + + axs[0].fill_between(t_psth, psth_div - std_div, + psth_div + std_div, alpha=0.4, color=colors[lid]) + axs[0].plot(t_psth, psth_div, alpha=1, color=colors[lid]) + + lab_max = idx[np.argmax(t_ints)] + label_pos.append((dividers[lab_max + 1] - dividers[lab_max]) / 2 + dividers[lab_max]) + + axs[1].imshow(raster[trial_idx], cmap='binary', origin='lower', + extent=[np.min(t_raster), np.max(t_raster), 0, len(trial_idx)], aspect='auto') + + width = raster_bin * 4 + for iD in range(len(dividers) - 1): + axs[1].fill_between([post_time + raster_bin / 2, post_time + raster_bin / 2 + width], + [dividers[iD + 1], dividers[iD + 1]], [dividers[iD], dividers[iD]], color=colors[iD]) + + axs[1].set_xlim([-1 * pre_time, post_time + raster_bin / 2 + width]) + secax = axs[1].secondary_yaxis('right') + + secax.set_yticks(label_pos) + secax.set_yticklabels(label, rotation=90, + rotation_mode='anchor', ha='center') + for ic, c in enumerate(np.array(colors)[lidx]): + secax.get_yticklabels()[ic].set_color(c) + + axs[0].axvline(0, *axs[0].get_ylim(), c='k', ls='--', zorder=10) # TODO this doesn't always work + axs[1].axvline(0, *axs[1].get_ylim(), c='k', ls='--', zorder=10) + + return fig, axs + + def plot_with_behavior(self): + + dlc = likelihood_threshold(self.dlc) + trial_idx, dividers = find_trial_ids(self.trials, sort='side') + feature_ext = get_speed(self.dlc, self.camera_times, self.label, feature='paw_r') + feature_new = get_speed(self.dlc, self.new_times, self.label, feature='paw_r') + + fig = plt.figure() + fig.set_size_inches(15, 9) + gs = gridspec.GridSpec(1, 5, figure=fig, width_ratios=[4, 1, 1, 1, 3], wspace=0.3, hspace=0.5) + gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0]) + ax01 = fig.add_subplot(gs0[0, 0]) + ax02 = fig.add_subplot(gs0[1, 0]) + ax03 = fig.add_subplot(gs0[2, 0]) + gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1], height_ratios=[1, 3]) + ax11 = fig.add_subplot(gs1[0, 0]) + ax12 = fig.add_subplot(gs1[1, 0]) + gs2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 2], height_ratios=[1, 3]) + ax21 = fig.add_subplot(gs2[0, 0]) + ax22 = fig.add_subplot(gs2[1, 0]) + gs3 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 3], height_ratios=[1, 3]) + ax31 = fig.add_subplot(gs3[0, 0]) + ax32 = fig.add_subplot(gs3[1, 0]) + gs4 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 4]) + ax41 = fig.add_subplot(gs4[0, 0]) + ax42 = fig.add_subplot(gs4[1, 0]) + + ax01.plot(self.t_shifts, self.shifts, label='shifts') + ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt') + ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10) + ax01.legend() + ax01.set_ylabel('Frames') + ax01.set_xlabel('Time in session') + + xs = np.searchsorted(self.ttl_times, self.t_shifts) + ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps'] + ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl') + ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10) + ax02.legend() + ax02.set_ylabel('Frames') + ax02.set_xlabel('Time in session') + + ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], + 'k', label='extracted - new') + ax03.legend() + ax03.set_ylim(-5, 5) + ax03.set_ylabel('Frames') + ax03.set_xlabel('Time in session') + + self.single_cluster_raster(self.wheel_timestamps, self.trials['firstMovement_times'].values, trial_idx, dividers, + ['g', 'y'], ['left', 'right'], weights=self.wheel_vel, fr=False, axs=[ax11, ax12]) + ax11.sharex(ax12) + ax11.set_ylabel('Wheel velocity') + ax11.set_title('Wheel') + ax12.set_xlabel('Time from first move') + + self.single_cluster_raster(self.camera_times, self.trials['firstMovement_times'].values, trial_idx, dividers, + ['g', 'y'], ['left', 'right'], weights=feature_ext, fr=False, axs=[ax21, ax22]) + ax21.sharex(ax22) + ax21.set_ylabel('Paw r velocity') + ax21.set_title('Extracted times') + ax22.set_xlabel('Time from first move') + + self.single_cluster_raster(self.new_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'], + ['left', 'right'], weights=feature_new, fr=False, axs=[ax31, ax32]) + ax31.sharex(ax32) + ax31.set_ylabel('Paw r velocity') + ax31.set_title('New times') + ax32.set_xlabel('Time from first move') + + ax41.imshow(self.frame_example[0]) + rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], + self.roi[0][1] - self.roi[0][0], + linewidth=4, edgecolor='g', facecolor='none') + ax41.add_patch(rect) + + ax42.plot(self.all_me) + + return fig + + def plot_without_behavior(self): + + fig = plt.figure() + fig.set_size_inches(7, 7) + gs = gridspec.GridSpec(1, 2, figure=fig) + gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0]) + ax01 = fig.add_subplot(gs0[0, 0]) + ax02 = fig.add_subplot(gs0[1, 0]) + ax03 = fig.add_subplot(gs0[2, 0]) + + gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1]) + ax04 = fig.add_subplot(gs1[0, 0]) + ax05 = fig.add_subplot(gs1[1, 0]) + + ax01.plot(self.t_shifts, self.shifts, label='shifts') + ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt') + ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10) + ax01.legend() + ax01.set_ylabel('Frames') + ax01.set_xlabel('Time in session') + + xs = np.searchsorted(self.ttl_times, self.t_shifts) + ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps'] + ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl') + ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10) + ax02.legend() + ax02.set_ylabel('Frames') + ax02.set_xlabel('Time in session') + + ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], + 'k', label='extracted - new') + ax03.legend() + ax03.set_ylim(-5, 5) + ax03.set_ylabel('Frames') + ax03.set_xlabel('Time in session') + + ax04.imshow(self.frame_example[0]) + rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], + self.roi[0][1] - self.roi[0][0], + linewidth=4, edgecolor='g', facecolor='none') + ax04.add_patch(rect) + + ax05.plot(self.all_me) + + return fig + + def process(self): + + # Compute the motion energy of the wheel for the whole video + wg = WindowGenerator(self.camera_meta['length'], 5000, 4) + out = Parallel(n_jobs=self.nprocess)(delayed(self.compute_motion_energy)(first, last, wg, iw) + for iw, (first, last) in enumerate(wg.firstlast)) + # Concatenate the motion energy into one big array + self.all_me = np.array([]) + for vals in out[:-1]: + self.all_me = np.r_[self.all_me, vals] + + frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) + toverlap = self.twin - 1 + all_me = np.r_[np.full((int(self.camera_meta['fps'] * toverlap)), np.nan), self.all_me] + to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / frate)[::-1] + times = np.r_[to_app, self.times] + + wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), + int(self.camera_meta['fps'] * toverlap)) + + out = Parallel(n_jobs=self.nprocess)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) + for iw, (first, last) in enumerate(wg.firstlast)) + + self.shifts = np.array([]) + self.t_shifts = np.array([]) + for vals in out[:-1]: + self.shifts = np.r_[self.shifts, vals[0]] + self.t_shifts = np.r_[self.t_shifts, vals[1]] + + idx = np.bitwise_and(self.t_shifts >= self.ttl_times[0], self.t_shifts < self.ttl_times[-1]) + self.shifts = self.shifts[idx] + self.t_shifts = self.t_shifts[idx] + shifts_filt = ndimage.percentile_filter(self.shifts, 80, 120) + shifts_filt = self.clean_shifts(shifts_filt, n=1) + self.shifts_filt = self.clean_shifts(shifts_filt, n=2) + + self.qc = self.qc_shifts(self.shifts, self.shifts_filt) + + self.new_times = self.extract_times(self.shifts_filt, self.t_shifts) + + return self.new_times From f77bb4a1d6505e83c44934b38adb0df272c3a08c Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 26 Sep 2023 13:53:26 +0100 Subject: [PATCH 07/68] add video wheel alignment to fpga camera extraction --- ibllib/io/extractors/camera.py | 19 ++++++++- ibllib/io/extractors/video_motion.py | 61 ++++++++++++++-------------- 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index f3e5dbd1d..707e0764f 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -16,6 +16,7 @@ from ibllib.io.extractors.base import get_session_extractor_type from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map import ibllib.io.raw_data_loaders as raw +from ibllib.io.extractors.video_motion import MotionAlignmentFullSession from ibllib.io.extractors.base import ( BaseBpodTrialsExtractor, BaseExtractor, @@ -148,12 +149,26 @@ def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio', except AssertionError as ex: _logger.critical('Failed to extract using %s: %s', sync_label, ex) - # If you reach here extracting using sync TTLs was not possible - _logger.warning('Alignment by wheel data not yet implemented') + # If you reach here extracting using sync TTLs was not possible, we attempt to align using wheel motion energy + _logger.warning('Attempting to align using wheel') + + try: + motion_class = MotionAlignmentFullSession(self.session_path, self.label, behavior=False) + new_times = motion_class.process() + if not motion_class.qc_outcome: + raise ValueError(f'Wheel alignment failed to pass qc: {motion_class.qc}') + else: + _logger.warning(f'Wheel alignment successful, qc: {motion_class.qc}') + return new_times + + except Exception as err: + _logger.critical(f'Failed to align with wheel: {err}') + if length < raw_ts.size: df = raw_ts.size - length _logger.info(f'Discarding first {df} pulses') raw_ts = raw_ts[df:] + return raw_ts diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index cde4b393c..c703f243a 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -394,8 +394,6 @@ def process_key(event): plt.show() -session_path = one.eid2path(eid) -session_path = Path('/mnt/ibl').joinpath(*session_path.parts[-5:]) class MotionAlignmentFullSession: def __init__(self, session_path, label, **kwargs): self.session_path = session_path @@ -422,9 +420,9 @@ def fix_keys(alf_object): self.wheel_timestamps = wheel.timestamps wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000) self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000) - self.camera_times = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.times.*.npy'))) + self.camera_times = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.times*.npy'))) self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob( - f'_iblrig_{self.label}Camera.raw.*.mp4'))) + f'_iblrig_{self.label}Camera.raw*.mp4'))) self.camera_meta = vidio.get_video_meta(self.camera_path) # TODO should read in the description file to get the correct sync location @@ -448,8 +446,10 @@ def fix_keys(alf_object): self.times = self.ttls[self.tdiff:] self.short_flag = False + self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) + if behavior: - self.trials = alfio.load_file_content(next(alf_path.glob(f'_ibl_trials.table.*.pqt'))) + self.trials = alfio.load_file_content(next(alf_path.glob('_ibl_trials.table.*.pqt'))) self.dlc = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.dlc.*.pqt'))) self.dlc = likelihood_threshold(self.dlc) @@ -614,7 +614,6 @@ def clean_shifts(self, x, n=1): return np.cumsum(dy) + y[0] - def qc_shifts(self, shifts, shifts_filt): ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100 if self.tdiff < 0 else 0 @@ -622,39 +621,40 @@ def qc_shifts(self, shifts, shifts_filt): shifts_sum = np.where(np.abs(np.diff(shifts)) > 10)[0].size shifts_filt_sum = np.where(np.abs(np.diff(shifts_filt)) > 1)[0].size - qc = {} + qc = dict() qc['ttl_per'] = ttl_per qc['nan_per'] = nan_per qc['shifts_sum'] = shifts_sum qc['shifts_filt_sum'] = shifts_filt_sum - return qc - - # # If more than 10% of ttls are missing we don't get new times - # if ttl_per > 10: - # return False - # # If too many of the shifts are nans it means the alignment is not accurate - # if nan_per > 40: - # return False - # # If there are too many artefacts could be errors - # if shifts_sum > 80: - # return False - # # If there are jumps > 1 in the filtered shifts then there is a problem - # if shifts_filt_sum > 0: - # return False - # - # return True + qc_outcome = True + # If more than 10% of ttls are missing we don't get new times + if ttl_per > 10: + qc_outcome = False + # If too many of the shifts are nans it means the alignment is not accurate + if nan_per > 40: + qc_outcome = False + # If there are too many artefacts could be errors + if shifts_sum > 60: + qc_outcome = False + # If there are jumps > 1 in the filtered shifts then there is a problem + if shifts_filt_sum > 0: + qc_outcome = False + + return qc, qc_outcome def extract_times(self, shifts_filt, t_shifts): - fps = 1 / np.nanmean(np.diff(self.ttl_times)) - t_new = t_shifts - (shifts_filt * 1 / fps) + t_new = t_shifts - (shifts_filt * 1 / self.frate) fcn = interpolate.interp1d(t_shifts, t_new, fill_value="extrapolate") new_times = fcn(self.ttl_times) - # TODO if short we need to interpolate the end times + if self.tdiff < 0: + to_app = (np.arange(np.abs(self.tdiff), ) + 1) / self.frate + new_times[-1] + new_times = np.r_[new_times, to_app] return new_times + @staticmethod def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True, norm=False, @@ -728,7 +728,7 @@ def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labe def plot_with_behavior(self): - dlc = likelihood_threshold(self.dlc) + self.dlc = likelihood_threshold(self.dlc) trial_idx, dividers = find_trial_ids(self.trials, sort='side') feature_ext = get_speed(self.dlc, self.camera_times, self.label, feature='paw_r') feature_new = get_speed(self.dlc, self.new_times, self.label, feature='paw_r') @@ -790,7 +790,7 @@ def plot_with_behavior(self): ax22.set_xlabel('Time from first move') self.single_cluster_raster(self.new_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'], - ['left', 'right'], weights=feature_new, fr=False, axs=[ax31, ax32]) + ['left', 'right'], weights=feature_new, fr=False, axs=[ax31, ax32]) ax31.sharex(ax32) ax31.set_ylabel('Paw r velocity') ax31.set_title('New times') @@ -863,10 +863,9 @@ def process(self): for vals in out[:-1]: self.all_me = np.r_[self.all_me, vals] - frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) toverlap = self.twin - 1 all_me = np.r_[np.full((int(self.camera_meta['fps'] * toverlap)), np.nan), self.all_me] - to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / frate)[::-1] + to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / self.frate)[::-1] times = np.r_[to_app, self.times] wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), @@ -888,7 +887,7 @@ def process(self): shifts_filt = self.clean_shifts(shifts_filt, n=1) self.shifts_filt = self.clean_shifts(shifts_filt, n=2) - self.qc = self.qc_shifts(self.shifts, self.shifts_filt) + self.qc, self.qc_outcome = self.qc_shifts(self.shifts, self.shifts_filt) self.new_times = self.extract_times(self.shifts_filt, self.t_shifts) From e5b71cf7564d8ecb2674a8c99faca314ea67193d Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 26 Sep 2023 13:53:55 +0100 Subject: [PATCH 08/68] report server health to data repo not lab --- ibllib/pipes/local_server.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ibllib/pipes/local_server.py b/ibllib/pipes/local_server.py index 895b0f20b..47f6322b5 100644 --- a/ibllib/pipes/local_server.py +++ b/ibllib/pipes/local_server.py @@ -10,7 +10,7 @@ from one.api import ONE from one.webclient import AlyxClient -from one.remote.globus import get_lab_from_endpoint_id +from one.remote.globus import get_lab_from_endpoint_id, get_local_endpoint_id from iblutil.util import setup_logger from ibllib.io.extractors.base import get_pipeline, get_task_protocol, get_session_extractor_type @@ -74,9 +74,10 @@ def report_health(one): status.update(_get_volume_usage('/mnt/s0/Data', 'raid')) status.update(_get_volume_usage('/', 'system')) - lab_names = get_lab_from_endpoint_id(alyx=one.alyx) - for ln in lab_names: - one.alyx.json_field_update(endpoint='labs', uuid=ln, field_name='json', data=status) + data_repos = one.alyx.rest('data-repository', 'list', globus_endpoint_id=get_local_endpoint_id()) + + for dr in data_repos: + one.alyx.json_field_update(endpoint='data-repository', uuid=dr['name'], field_name='json', data=status) def job_creator(root_path, one=None, dry=False, rerun=False, max_md5_size=None): From a8850a2ad947947b566c9183d4a9c3055232b51f Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 26 Sep 2023 14:27:34 +0100 Subject: [PATCH 09/68] circular imports --- ibllib/io/extractors/video_motion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index c703f243a..a2bd002f0 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -21,7 +21,7 @@ from iblutil.util import Bunch from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map import ibllib.io.raw_data_loaders as raw -from ibllib.io.extractors.camera import CameraTimestampsBpod +import ibllib.io.extractors.camera as cam import brainbox.video as video import brainbox.behavior.wheel as wh from brainbox.singlecell import bin_spikes @@ -431,7 +431,7 @@ def fix_keys(alf_object): sr = get_sync_fronts(sync, chmap[f'{self.label}_camera']) self.ttls = sr.times[::2] else: - cam_extractor = CameraTimestampsBpod(session_path=self.session_path) + cam_extractor = cam.CameraTimestampsBpod(session_path=self.session_path) cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data') self.ttls = cam_extractor._times_from_bpod() From 5c595360574c4157ed7677a04c7cc3caf5886ba1 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 27 Sep 2023 11:56:56 +0100 Subject: [PATCH 10/68] upload plot to alyx --- ibllib/io/extractors/camera.py | 2 +- ibllib/io/extractors/video_motion.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index 707e0764f..e4b201674 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -153,7 +153,7 @@ def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio', _logger.warning('Attempting to align using wheel') try: - motion_class = MotionAlignmentFullSession(self.session_path, self.label, behavior=False) + motion_class = MotionAlignmentFullSession(self.session_path, self.label, behavior=False, upload=True) new_times = motion_class.process() if not motion_class.qc_outcome: raise ValueError(f'Wheel alignment failed to pass qc: {motion_class.qc}') diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index a2bd002f0..af188c0d8 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -22,6 +22,7 @@ from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map import ibllib.io.raw_data_loaders as raw import ibllib.io.extractors.camera as cam +from ibllib.plots.snapshot import ReportSnapshot import brainbox.video as video import brainbox.behavior.wheel as wh from brainbox.singlecell import bin_spikes @@ -400,12 +401,17 @@ def __init__(self, session_path, label, **kwargs): self.label = label self.threshold = kwargs.get('threshold', 20) self.behavior = kwargs.get('behavior', False) + self.upload = kwargs.get('upload', False) self.twin = kwargs.get('twin', 150) self.nprocess = kwargs.get('nprocess', int(cpu_count() - cpu_count() / 4)) self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None), behavior=self.behavior) self.roi, self.mask = self.get_roi_mask() + if self.upload: + self.one = ONE(mode='remote') + self.eid = self.one.path2eid(self.session_path) + def load_data(self, sync='nidq', location=None, behavior=False): def fix_keys(alf_object): ob = Bunch() @@ -891,4 +897,13 @@ def process(self): self.new_times = self.extract_times(self.shifts_filt, self.t_shifts) + if self.upload: + fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior() + save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', 'video_wheel_alignment.png')) + save_fig_path.parent.mkdir(exist_ok=True, parents=True) + fig.savefig(save_fig_path) + snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one) + snp.outputs = [save_fig_path] + snp.register_images(widths=['orig']) + return self.new_times From 3cd2ca8da5cfc68a540c59ba9505312e5a158ee8 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 27 Sep 2023 12:40:44 +0100 Subject: [PATCH 11/68] remove circular imports --- ibllib/io/extractors/camera.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index e4b201674..4bcb0699c 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -16,7 +16,7 @@ from ibllib.io.extractors.base import get_session_extractor_type from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map import ibllib.io.raw_data_loaders as raw -from ibllib.io.extractors.video_motion import MotionAlignmentFullSession +import ibllib.io.extractors.video_motion as vmotion from ibllib.io.extractors.base import ( BaseBpodTrialsExtractor, BaseExtractor, @@ -153,7 +153,8 @@ def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio', _logger.warning('Attempting to align using wheel') try: - motion_class = MotionAlignmentFullSession(self.session_path, self.label, behavior=False, upload=True) + motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, behavior=False, + upload=True) new_times = motion_class.process() if not motion_class.qc_outcome: raise ValueError(f'Wheel alignment failed to pass qc: {motion_class.qc}') From 935c0c7bb3813d7089f7ea8b8bfbf9f7489c3024 Mon Sep 17 00:00:00 2001 From: Chris Langfield <34426450+chris-langfield@users.noreply.github.com> Date: Wed, 4 Oct 2023 09:24:07 +0100 Subject: [PATCH 12/68] Density gain option (#652) * add gain option and cleanup docstring * flake --------- Co-authored-by: chris-langfield --- ibllib/plots/misc.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/ibllib/plots/misc.py b/ibllib/plots/misc.py index 133eb12e8..36cd56afb 100644 --- a/ibllib/plots/misc.py +++ b/ibllib/plots/misc.py @@ -74,13 +74,19 @@ def insert_zeros(trace): class Density: - def __init__(self, w, fs=1, cmap='Greys_r', ax=None, taxis=0, title=None, **kwargs): + def __init__(self, w, fs=30_000, cmap='Greys_r', ax=None, taxis=0, title=None, gain=None, **kwargs): """ - Matplotlib display of traces as a density display + Matplotlib display of traces as a density display using `imshow()`. :param w: 2D array (numpy array dimension nsamples, ntraces) - :param fs: sampling frequency (Hz) - :param ax: axis to plot in + :param fs: sampling frequency (Hz). [default: 30000] + :param cmap: Name of MPL colormap to use in `imshow()`. [default: 'Greys_r'] + :param ax: Axis to plot in. If `None`, a new one is created. [default: `None`] + :param taxis: Time axis of input array (w). [default: 0] + :param title: Title to display on plot. [default: `None`] + :param gain: Gain in dB to display. Note: overrides `vmin` and `vmax` kwargs to `imshow()`. + Default: [`None` (auto)] + :param kwargs: Key word arguments passed to `imshow()` :return: None """ w = w.reshape(w.shape[0], -1) @@ -98,6 +104,9 @@ def __init__(self, w, fs=1, cmap='Greys_r', ax=None, taxis=0, title=None, **kwar self.figure, ax = plt.subplots() else: self.figure = ax.get_figure() + if gain: + kwargs["vmin"] = - 4 * (10 ** (gain / 20)) + kwargs["vmax"] = -kwargs["vmin"] self.im = ax.imshow(w, aspect='auto', cmap=cmap, extent=extent, origin=origin, **kwargs) ax.set_ylabel(ylabel) ax.set_xlabel(xlabel) From 6bc95d23fdb770109c344f0f306cd76d551ad06a Mon Sep 17 00:00:00 2001 From: Gaelle Date: Fri, 6 Oct 2023 12:17:42 +0200 Subject: [PATCH 13/68] doc probe geometry --- .../loading_data/loading_raw_ephys_data.ipynb | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/examples/loading_data/loading_raw_ephys_data.ipynb b/examples/loading_data/loading_raw_ephys_data.ipynb index 3c5b3153d..f8fd8ed37 100644 --- a/examples/loading_data/loading_raw_ephys_data.ipynb +++ b/examples/loading_data/loading_raw_ephys_data.ipynb @@ -326,6 +326,58 @@ "destriped = destripe(raw, fs=sr.fs)" ] }, + { + "cell_type": "markdown", + "source": [ + "## Get the probe geometry" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Using the `eid` and `probe` information" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import brainbox\n", + "channels = brainbox.io.one.load_channel_locations(eid, probe)\n", + "channels[probe][\"localCoordinates\"]" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Using the reader and the `.cbin` file" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "sr = spikeglx.Reader(bin_file)\n", + "sr.geometry" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "id": "9851b10d", From a49741123a90b34be005f2f375df404d95e3c7af Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 6 Oct 2023 15:35:46 +0300 Subject: [PATCH 14/68] sync extension optional; reduce logging clutter --- ibllib/pipes/dynamic_pipeline.py | 2 +- ibllib/pipes/local_server.py | 6 +----- ibllib/pipes/training_status.py | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/ibllib/pipes/dynamic_pipeline.py b/ibllib/pipes/dynamic_pipeline.py index d4d67f374..b721d4e9c 100644 --- a/ibllib/pipes/dynamic_pipeline.py +++ b/ibllib/pipes/dynamic_pipeline.py @@ -152,7 +152,7 @@ def make_pipeline(session_path, **pkwargs): # Syncing tasks (sync, sync_args), = acquisition_description['sync'].items() sync_args['sync_collection'] = sync_args.pop('collection') # rename the key so it matches task run arguments - sync_args['sync_ext'] = sync_args.pop('extension') + sync_args['sync_ext'] = sync_args.pop('extension', None) sync_args['sync_namespace'] = sync_args.pop('acquisition_software', None) sync_kwargs = {'sync': sync, **sync_args} sync_tasks = [] diff --git a/ibllib/pipes/local_server.py b/ibllib/pipes/local_server.py index 895b0f20b..b07daf4d5 100644 --- a/ibllib/pipes/local_server.py +++ b/ibllib/pipes/local_server.py @@ -109,10 +109,7 @@ def job_creator(root_path, one=None, dry=False, rerun=False, max_md5_size=None): list of dicts A list of any datasets registered (only for legacy sessions) """ - for _ in range(10): - _logger.info('#' * 110) _logger.info('Start looking for new sessions...') - _logger.info('#' * 110) if not one: one = ONE(cache_rest=None) rc = IBLRegistrationClient(one=one) @@ -152,8 +149,7 @@ def job_creator(root_path, one=None, dry=False, rerun=False, max_md5_size=None): if pipe is not None: pipes.append(pipe) except Exception: - _logger.error(traceback.format_exc()) - _logger.warning(f'Creating session / registering raw datasets {session_path} errored') + _logger.error(f'Failed to register session %s:\n%s', session_path.relative_to(root_path), traceback.format_exc()) continue return pipes, all_datasets diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index fc73304c6..fa102ee68 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -475,7 +475,7 @@ def get_data_collection(session_path): collections = [pipeline.tasks.get(task).kwargs['collection'] for task in trials_tasks] if len(collections) == 1 and collections[0] == 'raw_behavior_data': alf_collections = ['alf'] - elif all(['raw_task_data' in c for c in collections]): + elif all('raw_task_data' in c for c in collections): alf_collections = [f'alf/task_{c[-2:]}' for c in collections] else: alf_collections = None From 6c0495662373daca5331eeb4c31f27562a6aaa35 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 6 Oct 2023 15:41:43 +0300 Subject: [PATCH 15/68] flake8 --- ibllib/pipes/local_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibllib/pipes/local_server.py b/ibllib/pipes/local_server.py index b07daf4d5..464cdc38d 100644 --- a/ibllib/pipes/local_server.py +++ b/ibllib/pipes/local_server.py @@ -149,7 +149,7 @@ def job_creator(root_path, one=None, dry=False, rerun=False, max_md5_size=None): if pipe is not None: pipes.append(pipe) except Exception: - _logger.error(f'Failed to register session %s:\n%s', session_path.relative_to(root_path), traceback.format_exc()) + _logger.error('Failed to register session %s:\n%s', session_path.relative_to(root_path), traceback.format_exc()) continue return pipes, all_datasets From 68fb8cd807ca91d5a51ff9c985442934825ce8ac Mon Sep 17 00:00:00 2001 From: k1o0 Date: Fri, 6 Oct 2023 16:03:36 +0300 Subject: [PATCH 16/68] Task qc extractor refactor (#649) * Habituation extract phase and position * Independent task QC method in behaviour tasks * var names * flake8 * Initialize settings property --- .gitignore | 1 + ibllib/ephys/ephysqc.py | 2 +- ibllib/io/extractors/biased_trials.py | 24 +- ibllib/io/extractors/ephys_fpga.py | 60 +++-- ibllib/io/extractors/habituation_trials.py | 18 +- ibllib/io/extractors/mesoscope.py | 21 +- ibllib/io/extractors/training_trials.py | 10 +- ibllib/io/extractors/training_wheel.py | 4 +- ibllib/io/session_params.py | 2 +- ibllib/pipes/behavior_tasks.py | 190 +++++++++------- ibllib/qc/task_extractors.py | 9 +- ibllib/qc/task_metrics.py | 252 ++++++++++----------- 12 files changed, 320 insertions(+), 273 deletions(-) diff --git a/.gitignore b/.gitignore index 906c5d9ac..e291b8572 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__ python/scratch .idea/* .vscode/ +*.code-workspace *checkpoint.ipynb build/ venv/ diff --git a/ibllib/ephys/ephysqc.py b/ibllib/ephys/ephysqc.py index b8721bfe2..16ab9f870 100644 --- a/ibllib/ephys/ephysqc.py +++ b/ibllib/ephys/ephysqc.py @@ -580,7 +580,7 @@ def _qc_from_path(sess_path, display=True): sync, chmap = ephys_fpga.get_main_probe_sync(sess_path, bin_exists=False) _ = ephys_fpga.extract_all(sess_path, output_path=temp_alf_folder, save=True) # check that the output is complete - fpga_trials = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display) + fpga_trials, *_ = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display) # align with the bpod bpod2fpga = ephys_fpga.align_with_bpod(temp_alf_folder.parent) alf_trials = alfio.load_object(temp_alf_folder, 'trials') diff --git a/ibllib/io/extractors/biased_trials.py b/ibllib/io/extractors/biased_trials.py index c7c16d6c0..16d8f8111 100644 --- a/ibllib/io/extractors/biased_trials.py +++ b/ibllib/io/extractors/biased_trials.py @@ -95,12 +95,12 @@ class TrialsTableBiased(BaseBpodTrialsExtractor): intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times Additionally extracts the following wheel data: - wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude + wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude """ save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None) - var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement') + var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement') def _extract(self, extractor_classes=None, **kwargs): base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType, @@ -120,13 +120,13 @@ class TrialsTableEphys(BaseBpodTrialsExtractor): intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times Additionally extracts the following wheel data: - wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude + wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude """ save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, '_ibl_trials.quiescencePeriod.npy') - var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', + var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence') def _extract(self, extractor_classes=None, **kwargs): @@ -154,16 +154,16 @@ class BiasedTrials(BaseBpodTrialsExtractor): None, None, '_ibl_trials.quiescencePeriod.npy') var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', - 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included', + 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included', 'phase', 'position', 'quiescence') - def _extract(self, extractor_classes=None, **kwargs): + def _extract(self, extractor_classes=None, **kwargs) -> dict: base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials, PhasePosQuiescence] # Exclude from trials table out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) - return tuple(out.pop(x) for x in self.var_names) + return {k: out[k] for k in self.var_names} class EphysTrials(BaseBpodTrialsExtractor): @@ -177,16 +177,16 @@ class EphysTrials(BaseBpodTrialsExtractor): '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy') var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', - 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included', + 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included', 'phase', 'position', 'quiescence') - def _extract(self, extractor_classes=None, **kwargs): + def _extract(self, extractor_classes=None, **kwargs) -> dict: base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence] # Exclude from trials table out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) - return tuple(out.pop(x) for x in self.var_names) + return {k: out[k] for k in self.var_names} def extract_all(session_path, save=False, bpod_trials=False, settings=False, extra_classes=None, diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 98bdcdd25..74ac1e551 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -17,7 +17,7 @@ from iblutil.spacer import Spacer import ibllib.exceptions as err -from ibllib.io import raw_data_loaders, session_params +from ibllib.io import raw_data_loaders as raw, session_params from ibllib.io.extractors.bpod_trials import extract_all as bpod_extract_all import ibllib.io.extractors.base as extractors_base from ibllib.io.extractors.training_wheel import extract_wheel_moves @@ -554,7 +554,7 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm ax.set_yticks([0, 1, 2, 3, 4, 5]) ax.set_ylim([0, 5]) - return trials + return trials, frame2ttl, audio, bpod def extract_sync(session_path, overwrite=False, ephys_files=None, namespace='spikeglx'): @@ -734,6 +734,7 @@ def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs): super().__init__(*args, **kwargs) self.bpod2fpga = None self.bpod_trials = bpod_trials + self.frame2ttl = self.audio = self.bpod = self.settings = None if bpod_extractor: self.bpod_extractor = bpod_extractor self._update_var_names() @@ -750,14 +751,37 @@ def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None): A set of Bpod trials fields to keep. bpod_rsync_fields : tuple A set of Bpod trials fields to sync to the DAQ times. - - TODO Turn into property getter; requires ensuring the output field are the same for legacy """ if self.bpod_extractor: - self.var_names = self.bpod_extractor.var_names - self.save_names = self.bpod_extractor.save_names - self.bpod_rsync_fields = bpod_rsync_fields or self._time_fields(self.bpod_extractor.var_names) - self.bpod_fields = bpod_fields or [x for x in self.bpod_extractor.var_names if x not in self.bpod_rsync_fields] + for var_name, save_name in zip(self.bpod_extractor.var_names, self.bpod_extractor.save_names): + if var_name not in self.var_names: + self.var_names += (var_name,) + self.save_names += (save_name,) + + # self.var_names = self.bpod_extractor.var_names + # self.save_names = self.bpod_extractor.save_names + self.settings = self.bpod_extractor.settings # This is used by the TaskQC + self.bpod_rsync_fields = bpod_rsync_fields + if self.bpod_rsync_fields is None: + self.bpod_rsync_fields = tuple(self._time_fields(self.bpod_extractor.var_names)) + if 'table' in self.bpod_extractor.var_names: + if not self.bpod_trials: + self.bpod_trials = self.bpod_extractor.extract(save=False) + table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() + self.bpod_rsync_fields += tuple(self._time_fields(table_keys)) + elif bpod_rsync_fields: + self.bpod_rsync_fields = bpod_rsync_fields + excluded = (*self.bpod_rsync_fields, 'table') + if bpod_fields: + assert not set(self.bpod_fields).intersection(excluded), 'bpod_fields must not also be bpod_rsync_fields' + self.bpod_fields = bpod_fields + elif self.bpod_extractor: + self.bpod_fields = tuple(x for x in self.bpod_extractor.var_names if x not in excluded) + if 'table' in self.bpod_extractor.var_names: + if not self.bpod_trials: + self.bpod_trials = self.bpod_extractor.extract(save=False) + table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() + self.bpod_fields += (*[x for x in table_keys if x not in excluded], self.sync_field + '_bpod') @staticmethod def _time_fields(trials_attr) -> set: @@ -778,7 +802,8 @@ def _time_fields(trials_attr) -> set: pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$') return set(filter(pattern.match, trials_attr)) - def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs): + def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', + task_collection='raw_behavior_data', **kwargs) -> dict: """Extracts ephys trials by combining Bpod and FPGA sync pulses""" # extract the behaviour data from bpod if sync is None or chmap is None: @@ -804,7 +829,8 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task else: tmin = tmax = None - fpga_trials = extract_behaviour_sync( + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync( sync=sync, chmap=chmap, bpod_trials=self.bpod_trials, tmin=tmin, tmax=tmax) assert self.sync_field in self.bpod_trials and self.sync_field in fpga_trials self.bpod_trials[f'{self.sync_field}_bpod'] = np.copy(self.bpod_trials[self.sync_field]) @@ -827,18 +853,20 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task # extract the wheel data wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) from ibllib.io.extractors.training_wheel import extract_first_movement_times - settings = raw_data_loaders.load_settings(session_path=self.session_path, task_collection=task_collection) - min_qt = settings.get('QUIESCENT_PERIOD', None) + if not self.settings: + self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) + min_qt = self.settings.get('QUIESCENT_PERIOD', None) first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) out.update({'firstMovement_times': first_move_onsets}) # Re-create trials table trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) out['table'] = trials_table.to_df() + out.update({f'wheel_{k}': v for k, v in wheel.items()}) + out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) out = {k: out[k] for k in self.var_names if k in out} # Reorder output - assert tuple(filter(lambda x: 'wheel' not in x, self.var_names)) == tuple(out.keys()) - return [out[k] for k in out] + [wheel['timestamps'], wheel['position'], - moves['intervals'], moves['peakAmplitude']] + assert self.var_names == tuple(out.keys()) + return out def get_wheel_positions(self, *args, **kwargs): """Extract wheel and wheelMoves objects. @@ -882,7 +910,7 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ If save is True, a list of file paths to the extracted data. """ # Extract Bpod trials - bpod_raw = raw_data_loaders.load_data(session_path, task_collection=task_collection) + bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' bpod_trials, *_ = bpod_extract_all( session_path=session_path, bpod_trials=bpod_raw, task_collection=task_collection, diff --git a/ibllib/io/extractors/habituation_trials.py b/ibllib/io/extractors/habituation_trials.py index a78a57eef..9dedbd3d5 100644 --- a/ibllib/io/extractors/habituation_trials.py +++ b/ibllib/io/extractors/habituation_trials.py @@ -15,16 +15,15 @@ class HabituationTrials(BaseBpodTrialsExtractor): var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals', 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times', - 'stimCenterTrigger_times', 'stimCenter_times') + 'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - exclude = ['itiIn_times', 'stimOffTrigger_times', - 'stimCenter_times', 'stimCenterTrigger_times'] - self.save_names = tuple([f'_ibl_trials.{x}.npy' if x not in exclude else None - for x in self.var_names]) + exclude = ['itiIn_times', 'stimOffTrigger_times', 'stimCenter_times', + 'stimCenterTrigger_times', 'position', 'phase'] + self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names) - def _extract(self): + def _extract(self) -> dict: # Extract all trials... # Get all stim_sync events detected @@ -101,9 +100,14 @@ def _extract(self): ["iti"][0][0] for tr in self.bpod_trials] ) + # Phase and position + out['position'] = np.array([t['position'] for t in self.bpod_trials]) + out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials]) + # NB: We lose the last trial because the stim off event occurs at trial_num + 1 n_trials = out['stimOff_times'].size - return [out[k][:n_trials] for k in self.var_names] + # return [out[k][:n_trials] for k in self.var_names] + return {k: out[k][:n_trials] for k in self.var_names} def extract_all(session_path, save=False, bpod_trials=False, settings=False, task_collection='raw_behavior_data', save_path=None): diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 93491945e..561bb6343 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -100,7 +100,7 @@ def __init__(self, *args, sync_collection='raw_sync_data', **kwargs): super().__init__(*args, **kwargs) self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') - def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs): + def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: if not (sync or chmap): sync, chmap = load_timeline_sync_and_chmap( self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) @@ -110,20 +110,17 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa trials = super()._extract(sync, chmap, sync_collection, extractor_type='ephys', **kwargs) # If no protocol number is defined, trim timestamps based on Bpod trials intervals - trials_table = trials[self.var_names.index('table')] + trials_table = trials['table'] bpod = get_sync_fronts(sync, chmap['bpod']) if kwargs.get('protocol_number') is None: tmin = trials_table.intervals_0.iloc[0] - 1 tmax = trials_table.intervals_1.iloc[-1] # Ensure wheel is cut off based on trials - wheel_ts_idx = self.var_names.index('wheel_timestamps') - mask = np.logical_and(tmin <= trials[wheel_ts_idx], trials[wheel_ts_idx] <= tmax) - trials[wheel_ts_idx] = trials[wheel_ts_idx][mask] - wheel_pos_idx = self.var_names.index('wheel_position') - trials[wheel_pos_idx] = trials[wheel_pos_idx][mask] - move_idx = self.var_names.index('wheelMoves_intervals') - mask = np.logical_and(trials[move_idx][:, 0] >= tmin, trials[move_idx][:, 0] <= tmax) - trials[move_idx] = trials[move_idx][mask, :] + mask = np.logical_and(tmin <= trials['wheel_timestamps'], trials['wheel_timestamps'] <= tmax) + trials['wheel_timestamps'] = trials['wheel_timestamps'][mask] + trials['wheel_position'] = trials['wheel_position'][mask] + mask = np.logical_and(trials['wheelMoves_intervals'][:, 0] >= tmin, trials['wheelMoves_intervals'][:, 0] <= tmax) + trials['wheelMoves_intervals'] = trials['wheelMoves_intervals'][mask, :] else: tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod) bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax) @@ -138,7 +135,7 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa valve_open_times = self.get_valve_open_times(driver_ttls=driver_out) assert len(valve_open_times) == sum(trials_table.feedbackType == 1) # TODO Relax assertion correct = trials_table.feedbackType == 1 - trials[self.var_names.index('valveOpen_times')][correct] = valve_open_times + trials['valveOpen_times'][correct] = valve_open_times trials_table.feedback_times[correct] = valve_open_times # Replace audio events @@ -191,7 +188,7 @@ def first_true(arr): trials_table.feedback_times[~correct] = error_cue trials_table.goCue_times = go_cue - return trials + return {k: trials[k] for k in self.var_names} def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None): """ diff --git a/ibllib/io/extractors/training_trials.py b/ibllib/io/extractors/training_trials.py index dc13ed7dd..41a69d815 100644 --- a/ibllib/io/extractors/training_trials.py +++ b/ibllib/io/extractors/training_trials.py @@ -682,8 +682,8 @@ class TrialsTable(BaseBpodTrialsExtractor): """ save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None) - var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement') + var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement') def _extract(self, extractor_classes=None, **kwargs): base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType, @@ -703,16 +703,16 @@ class TrainingTrials(BaseBpodTrialsExtractor): '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, None) var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', - 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', + 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence') - def _extract(self): + def _extract(self) -> dict: base = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence] out, _ = run_extractor_classes( base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) - return tuple(out.pop(x) for x in self.var_names) + return {k: out[k] for k in self.var_names} def extract_all(session_path, save=False, bpod_trials=None, settings=None, task_collection='raw_behavior_data', save_path=None): diff --git a/ibllib/io/extractors/training_wheel.py b/ibllib/io/extractors/training_wheel.py index 617b5f1df..2f1aded8c 100644 --- a/ibllib/io/extractors/training_wheel.py +++ b/ibllib/io/extractors/training_wheel.py @@ -385,8 +385,8 @@ class Wheel(BaseBpodTrialsExtractor): save_names = ('_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, '_ibl_trials.firstMovement_times.npy', None) - var_names = ('wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'firstMovement_times', + var_names = ('wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'firstMovement_times', 'is_final_movement') def _extract(self): diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index 34e668ced..5bcaf2873 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -510,7 +510,7 @@ def prepare_experiment(session_path, acquisition_description=None, local=None, r # won't be preserved by create_basic_transfer_params by default remote = False if remote is False else params['REMOTE_DATA_FOLDER_PATH'] - # THis is in the docstring but still, if the session Path is absolute, we need to make it relative + # This is in the docstring but still, if the session Path is absolute, we need to make it relative if Path(session_path).is_absolute(): session_path = Path(*session_path.parts[-3:]) diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py index 7cc317c28..6f1c8d506 100644 --- a/ibllib/pipes/behavior_tasks.py +++ b/ibllib/pipes/behavior_tasks.py @@ -9,14 +9,12 @@ from ibllib.oneibl.registration import get_lab from ibllib.pipes import base_tasks -from ibllib.io.raw_data_loaders import load_settings +from ibllib.io.raw_data_loaders import load_settings, load_bpod_fronts from ibllib.qc.task_extractors import TaskQCExtractor from ibllib.qc.task_metrics import HabituationQC, TaskQC from ibllib.io.extractors.ephys_passive import PassiveChoiceWorld -from ibllib.io.extractors import bpod_trials -from ibllib.io.extractors.base import get_session_extractor_type from ibllib.io.extractors.bpod_trials import get_bpod_extractor -from ibllib.io.extractors.ephys_fpga import extract_all +from ibllib.io.extractors.ephys_fpga import FpgaTrials, get_sync_and_chn_map from ibllib.io.extractors.mesoscope import TimelineTrials from ibllib.pipes import training_status from ibllib.plots.figures import BehaviourPlots @@ -73,25 +71,43 @@ def signature(self): } return signature - def _run(self, update=True): + def _run(self, update=True, save=True): """ Extracts an iblrig training session """ - extractor = bpod_trials.get_bpod_extractor(self.session_path, task_collection=self.collection) - trials, output_files = extractor.extract(task_collection=self.collection, save=True) + trials, output_files = self._extract_behaviour(save=save) if trials is None: return None if self.one is None or self.one.offline: return output_files + # Run the task QC + self._run_qc(trials, update=update) + return output_files + + def _extract_behaviour(self, **kwargs): + self.extractor = get_bpod_extractor(self.session_path, task_collection=self.collection) + self.extractor.default_path = self.output_collection + return self.extractor.extract(task_collection=self.collection, **kwargs) + + def _run_qc(self, trials_data=None, update=True): + if not self.extractor or trials_data is None: + trials_data, _ = self._extract_behaviour(save=False) + if not trials_data: + raise ValueError('No trials data found') + # Compile task data for QC qc = HabituationQC(self.session_path, one=self.one) - qc.extractor = TaskQCExtractor(self.session_path, sync_collection=self.sync_collection, + qc.extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, sync_type=self.sync, task_collection=self.collection) + + # Currently only the data field is accessed + qc.extractor.data = qc.extractor.rename_data(trials_data.copy()) + namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' qc.run(update=update, namespace=namespace) - return output_files + return qc class TrialRegisterRaw(base_tasks.RegisterRawDataTask, base_tasks.BehaviourTask): @@ -213,6 +229,7 @@ def _run(self, **kwargs): class ChoiceWorldTrialsBpod(base_tasks.BehaviourTask): priority = 90 job_size = 'small' + extractor = None @property def signature(self): @@ -234,38 +251,53 @@ def signature(self): } return signature - def _run(self, update=True): + def _run(self, update=True, save=True): """ Extracts an iblrig training session """ - extractor = bpod_trials.get_bpod_extractor(self.session_path, task_collection=self.collection) - extractor.default_path = self.output_collection - trials, output_files = extractor.extract(task_collection=self.collection, save=True) + trials, output_files = self._extract_behaviour(save=save) if trials is None: return None if self.one is None or self.one.offline: return output_files + # Run the task QC + self._run_qc(trials) + + return output_files + + def _extract_behaviour(self, **kwargs): + self.extractor = get_bpod_extractor(self.session_path, task_collection=self.collection) + self.extractor.default_path = self.output_collection + return self.extractor.extract(task_collection=self.collection, **kwargs) + + def _run_qc(self, trials_data=None, update=True): + if not self.extractor or trials_data is None: + trials_data, _ = self._extract_behaviour(save=False) + if not trials_data: + raise ValueError('No trials data found') + # Compile task data for QC - type = get_session_extractor_type(self.session_path, task_collection=self.collection) - # FIXME Task data should not need re-extracting - if type == 'habituation': - qc = HabituationQC(self.session_path, one=self.one) - qc.extractor = TaskQCExtractor(self.session_path, one=self.one, sync_collection=self.sync_collection, - sync_type=self.sync, task_collection=self.collection) - else: # Update wheel data - qc = TaskQC(self.session_path, one=self.one) - qc.extractor = TaskQCExtractor(self.session_path, one=self.one, sync_collection=self.sync_collection, - sync_type=self.sync, task_collection=self.collection) - qc.extractor.wheel_encoding = 'X1' + qc_extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, + sync_type=self.sync, task_collection=self.collection) + qc_extractor.data = qc_extractor.rename_data(trials_data) + if type(self.extractor).__name__ == 'HabituationTrials': + qc = HabituationQC(self.session_path, one=self.one, log=_logger) + else: + qc = TaskQC(self.session_path, one=self.one, log=_logger) + qc_extractor.wheel_encoding = 'X1' + qc_extractor.settings = self.extractor.settings + qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts( + self.session_path, task_collection=self.collection) + qc.extractor = qc_extractor + # Aggregate and update Alyx QC fields namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' qc.run(update=update, namespace=namespace) - - return output_files + return qc -class ChoiceWorldTrialsNidq(base_tasks.BehaviourTask): +class ChoiceWorldTrialsNidq(ChoiceWorldTrialsBpod): priority = 90 job_size = 'small' @@ -312,21 +344,41 @@ def _behaviour_criterion(self, update=True): "sessions", eid, "extended_qc", {"behavior": int(good_enough)} ) - def _extract_behaviour(self): - dsets, out_files = extract_all(self.session_path, self.sync_collection, task_collection=self.collection, - save_path=self.session_path.joinpath(self.output_collection), - protocol_number=self.protocol_number, save=True) + def _extract_behaviour(self, save=True, **kwargs): + # Extract Bpod trials + bpod_trials, _ = super()._extract_behaviour(save=False, **kwargs) - return dsets, out_files + # Sync Bpod trials to FPGA + sync, chmap = get_sync_and_chn_map(self.session_path, self.sync_collection) + self.extractor = FpgaTrials(self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor) + outputs, files = self.extractor.extract( + save=save, sync=sync, chmap=chmap, path_out=self.session_path.joinpath(self.output_collection), + task_collection=self.collection, protocol_number=self.protocol_number, **kwargs) + return outputs, files - def _run_qc(self, trials_data, update=True, plot_qc=True): - # Run the task QC - qc = TaskQC(self.session_path, one=self.one, log=_logger) - qc.extractor = TaskQCExtractor(self.session_path, lazy=True, one=qc.one, sync_collection=self.sync_collection, + def _run_qc(self, trials_data=None, update=False, plot_qc=False): + if not self.extractor or trials_data is None: + trials_data, _ = self._extract_behaviour(save=False) + if not trials_data: + raise ValueError('No trials data found') + + # Compile task data for QC + qc_extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, sync_type=self.sync, task_collection=self.collection) - # Extract extra datasets required for QC - qc.extractor.data = trials_data # FIXME This line is pointless - qc.extractor.extract_data() + qc_extractor.data = qc_extractor.rename_data(trials_data.copy()) + if type(self.extractor).__name__ == 'HabituationTrials': + qc = HabituationQC(self.session_path, one=self.one, log=_logger) + else: + qc = TaskQC(self.session_path, one=self.one, log=_logger) + qc_extractor.settings = self.extractor.settings + # Add Bpod wheel data + wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps']) + qc_extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod + qc_extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] + qc_extractor.wheel_encoding = 'X4' + qc_extractor.frame_ttls = self.extractor.frame2ttl + qc_extractor.audio_ttls = self.extractor.audio + qc.extractor = qc_extractor # Aggregate and update Alyx QC fields namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' @@ -345,9 +397,10 @@ def _run_qc(self, trials_data, update=True, plot_qc=True): _logger.error('Could not create Trials QC Plot') _logger.error(traceback.format_exc()) self.status = -1 + return qc - def _run(self, update=True, plot_qc=True): - dsets, out_files = self._extract_behaviour() + def _run(self, update=True, plot_qc=True, save=True): + dsets, out_files = self._extract_behaviour(save=save) if not self.one or self.one.offline: return out_files @@ -378,63 +431,24 @@ def signature(self): for fn in filter(None, extractor.save_names)] return signature - def _extract_behaviour(self): + def _extract_behaviour(self, save=True, **kwargs): """Extract the Bpod trials data and Timeline acquired signals.""" # First determine the extractor from the task protocol - extractor = get_bpod_extractor(self.session_path, self.protocol, self.collection) - ret, _ = extractor.extract(save=False, task_collection=self.collection) - bpod_trials = {k: v for k, v in zip(extractor.var_names, ret)} + bpod_trials, _ = ChoiceWorldTrialsBpod._extract_behaviour(self, save=False, **kwargs) - trials = TimelineTrials(self.session_path, bpod_trials=bpod_trials) + # Sync Bpod trials to DAQ + self.extractor = TimelineTrials(self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor) save_path = self.session_path / self.output_collection - if not self._spacer_support(extractor.settings): + if not self._spacer_support(self.extractor.settings): _logger.warning('Protocol spacers not supported; setting protocol_number to None') self.protocol_number = None - dsets, out_files = trials.extract( - save=True, path_out=save_path, sync_collection=self.sync_collection, - task_collection=self.collection, protocol_number=self.protocol_number) - if not isinstance(dsets, dict): - dsets = {k: v for k, v in zip(trials.var_names, dsets)} - - self.timeline = trials.timeline # Store for QC later - self.frame2ttl = trials.frame2ttl - self.audio = trials.audio + dsets, out_files = self.extractor.extract( + save=save, path_out=save_path, sync_collection=self.sync_collection, + task_collection=self.collection, protocol_number=self.protocol_number, **kwargs) return dsets, out_files - def _run_qc(self, trials_data, update=True, **kwargs): - """ - Run the task QC and update Alyx with results. - - Parameters - ---------- - trials_data : dict - The extracted trials data. - update : bool - If true, update Alyx with the result. - - Notes - ----- - - Unlike the super class, currently the QC plots are not generated. - - Expects the frame2ttl and audio attributes to be set from running _extract_behaviour. - """ - # TODO Task QC extractor for Timeline - qc = TaskQC(self.session_path, one=self.one, log=_logger) - qc.extractor = TaskQCExtractor(self.session_path, lazy=True, one=qc.one, sync_collection=self.sync_collection, - sync_type=self.sync, task_collection=self.collection) - # Extract extra datasets required for QC - qc.extractor.data = TaskQCExtractor.rename_data(trials_data.copy()) - qc.extractor.load_raw_data() - - qc.extractor.frame_ttls = self.frame2ttl - qc.extractor.audio_ttls = self.audio - # qc.extractor.bpod_ttls = channel_events('bpod') - - # Aggregate and update Alyx QC fields - namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' - qc.run(update=update, namespace=namespace) - class TrainingStatus(base_tasks.BehaviourTask): priority = 90 diff --git a/ibllib/qc/task_extractors.py b/ibllib/qc/task_extractors.py index f0d46ed02..5f5269710 100644 --- a/ibllib/qc/task_extractors.py +++ b/ibllib/qc/task_extractors.py @@ -1,4 +1,5 @@ import logging +import warnings import numpy as np from scipy.interpolate import interp1d @@ -26,16 +27,16 @@ 'wheel_position', 'wheel_timestamps'] -class TaskQCExtractor(object): +class TaskQCExtractor: def __init__(self, session_path, lazy=False, one=None, download_data=False, bpod_only=False, sync_collection=None, sync_type=None, task_collection=None): """ - A class for extracting the task data required to perform task quality control + A class for extracting the task data required to perform task quality control. :param session_path: a valid session path :param lazy: if True, the data are not extracted immediately :param one: an instance of ONE, used to download the raw data if download_data is True :param download_data: if True, any missing raw data is downloaded via ONE - :param bpod_only: extract from from raw Bpod data only, even for FPGA sessions + :param bpod_only: extract from raw Bpod data only, even for FPGA sessions """ if not is_session_path(session_path): raise ValueError('Invalid session path') @@ -151,6 +152,8 @@ def extract_data(self): intervals_bpod to be assigned to the data attribute before calling this function. :return: """ + warnings.warn('The TaskQCExtractor.extract_data will be removed in the future, ' + 'use dynamic pipeline behaviour tasks instead.', DeprecationWarning) self.log.info(f'Extracting session: {self.session_path}') self.type = self.type or get_session_extractor_type(self.session_path, task_collection=self.task_collection) # Finds the sync type when it isn't explicitly set, if ephys we assume nidq otherwise bpod diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 36f2b4806..42361645d 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -69,21 +69,21 @@ class TaskQC(base.QC): """A class for computing task QC metrics""" criteria = dict() - criteria['default'] = {"PASS": 0.99, "WARNING": 0.90, "FAIL": 0} # Note: WARNING was 0.95 prior to Aug 2022 - criteria['_task_stimOff_itiIn_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_positive_feedback_stimOff_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_negative_feedback_stimOff_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_wheel_move_during_closed_loop'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_response_stimFreeze_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_detected_wheel_moves'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_trial_length'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_goCue_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_errorCue_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_stimOn_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_stimOff_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_stimFreeze_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_iti_delays'] = {"NOT_SET": 0} - criteria['_task_passed_trial_checks'] = {"NOT_SET": 0} + criteria['default'] = {'PASS': 0.99, 'WARNING': 0.90, 'FAIL': 0} # Note: WARNING was 0.95 prior to Aug 2022 + criteria['_task_stimOff_itiIn_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_positive_feedback_stimOff_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_negative_feedback_stimOff_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_wheel_move_during_closed_loop'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_response_stimFreeze_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_detected_wheel_moves'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_trial_length'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_goCue_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_errorCue_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_stimOn_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_stimOff_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_stimFreeze_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_iti_delays'] = {'NOT_SET': 0} + criteria['_task_passed_trial_checks'] = {'NOT_SET': 0} @staticmethod def _thresholding(qc_value, thresholds=None): @@ -100,7 +100,7 @@ def _thresholding(qc_value, thresholds=None): if qc_value is None or np.isnan(qc_value): return int(-1) elif (qc_value > MAX_BOUND) or (qc_value < MIN_BOUND): - raise ValueError("Values out of bound") + raise ValueError('Values out of bound') if 'PASS' in thresholds.keys() and qc_value >= thresholds['PASS']: return 0 if 'WARNING' in thresholds.keys() and qc_value >= thresholds['WARNING']: @@ -151,7 +151,7 @@ def compute(self, **kwargs): if self.extractor is None: kwargs['download_data'] = kwargs.pop('download_data', self.download_data) self.load_data(**kwargs) - self.log.info(f"Session {self.session_path}: Running QC on behavior data...") + self.log.info(f'Session {self.session_path}: Running QC on behavior data...') self.metrics, self.passed = get_bpodqc_metrics_frame( self.extractor.data, wheel_gain=self.extractor.settings['STIM_GAIN'], # The wheel gain @@ -229,7 +229,7 @@ def compute(self, download_data=None): # If download_data is None, decide based on whether eid or session path was provided ensure_data = self.download_data if download_data is None else download_data self.load_data(download_data=ensure_data) - self.log.info(f"Session {self.session_path}: Running QC on habituation data...") + self.log.info(f'Session {self.session_path}: Running QC on habituation data...') # Initialize checks prefix = '_task_' @@ -274,16 +274,16 @@ def compute(self, download_data=None): # Check event orders: trial_start < stim on < stim center < feedback < stim off check = prefix + 'trial_event_sequence' nans = ( - np.isnan(data["intervals"][:, 0]) | # noqa - np.isnan(data["stimOn_times"]) | # noqa - np.isnan(data["stimCenter_times"]) | - np.isnan(data["valveOpen_times"]) | # noqa - np.isnan(data["stimOff_times"]) + np.isnan(data['intervals'][:, 0]) | # noqa + np.isnan(data['stimOn_times']) | # noqa + np.isnan(data['stimCenter_times']) | + np.isnan(data['valveOpen_times']) | # noqa + np.isnan(data['stimOff_times']) ) - a = np.less(data["intervals"][:, 0], data["stimOn_times"], where=~nans) - b = np.less(data["stimOn_times"], data["stimCenter_times"], where=~nans) - c = np.less(data["stimCenter_times"], data["valveOpen_times"], where=~nans) - d = np.less(data["valveOpen_times"], data["stimOff_times"], where=~nans) + a = np.less(data['intervals'][:, 0], data['stimOn_times'], where=~nans) + b = np.less(data['stimOn_times'], data['stimCenter_times'], where=~nans) + c = np.less(data['stimCenter_times'], data['valveOpen_times'], where=~nans) + d = np.less(data['valveOpen_times'], data['stimOff_times'], where=~nans) metrics[check] = a & b & c & d & ~nans passed[check] = metrics[check].astype(float) @@ -291,7 +291,7 @@ def compute(self, download_data=None): # Check that the time difference between the visual stimulus center-command being # triggered and the stimulus effectively appearing in the center is smaller than 150 ms. check = prefix + 'stimCenter_delays' - metric = np.nan_to_num(data["stimCenter_times"] - data["stimCenterTrigger_times"], + metric = np.nan_to_num(data['stimCenter_times'] - data['stimCenterTrigger_times'], nan=np.inf) passed[check] = (metric <= 0.15) & (metric > 0) metrics[check] = metric @@ -375,9 +375,9 @@ def check_stimOn_goCue_delays(data, **_): """ # Calculate the difference between stimOn and goCue times. # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["goCue_times"] - data["stimOn_times"], nan=np.inf) + metric = np.nan_to_num(data['goCue_times'] - data['stimOn_times'], nan=np.inf) passed = (metric < 0.01) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -391,9 +391,9 @@ def check_response_feedback_delays(data, **_): :param data: dict of trial data with keys ('feedback_times', 'response_times', 'intervals') """ - metric = np.nan_to_num(data["feedback_times"] - data["response_times"], nan=np.inf) + metric = np.nan_to_num(data['feedback_times'] - data['response_times'], nan=np.inf) passed = (metric < 0.01) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -410,13 +410,13 @@ def check_response_stimFreeze_delays(data, **_): """ # Calculate the difference between stimOn and goCue times. # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["stimFreeze_times"] - data["response_times"], nan=np.inf) + metric = np.nan_to_num(data['stimFreeze_times'] - data['response_times'], nan=np.inf) # Test for valid values passed = ((metric < 0.1) & (metric > 0)).astype(float) # Finally remove no_go trials (stimFreeze triggered differently in no_go trials) # These values are ignored in calculation of proportion passed - passed[data["choice"] == 0] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['choice'] == 0] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -431,12 +431,12 @@ def check_stimOff_itiIn_delays(data, **_): 'choice') """ # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["itiIn_times"] - data["stimOff_times"], nan=np.inf) + metric = np.nan_to_num(data['itiIn_times'] - data['stimOff_times'], nan=np.inf) passed = ((metric < 0.01) & (metric >= 0)).astype(float) # Remove no_go trials (stimOff triggered differently in no_go trials) # NaN values are ignored in calculation of proportion passed - metric[data["choice"] == 0] = passed[data["choice"] == 0] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[data['choice'] == 0] = passed[data['choice'] == 0] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -451,14 +451,14 @@ def check_iti_delays(data, **_): :param data: dict of trial data with keys ('stimOff_times', 'intervals') """ # Initialize array the length of completed trials - metric = np.full(data["intervals"].shape[0], np.nan) + metric = np.full(data['intervals'].shape[0], np.nan) passed = metric.copy() # Get the difference between stim off and the start of the next trial # Missing data are set to Inf, except for the last trial which is a NaN metric[:-1] = \ - np.nan_to_num(data["intervals"][1:, 0] - data["stimOff_times"][:-1] - 0.5, nan=np.inf) + np.nan_to_num(data['intervals'][1:, 0] - data['stimOff_times'][:-1] - 0.5, nan=np.inf) passed[:-1] = np.abs(metric[:-1]) < .5 # Last trial is not counted - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -474,11 +474,11 @@ def check_positive_feedback_stimOff_delays(data, **_): 'correct') """ # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["stimOff_times"] - data["feedback_times"] - 1, nan=np.inf) + metric = np.nan_to_num(data['stimOff_times'] - data['feedback_times'] - 1, nan=np.inf) passed = (np.abs(metric) < 0.15).astype(float) # NaN values are ignored in calculation of proportion passed; ignore incorrect trials here - metric[~data["correct"]] = passed[~data["correct"]] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[~data['correct']] = passed[~data['correct']] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -492,12 +492,12 @@ def check_negative_feedback_stimOff_delays(data, **_): :param data: dict of trial data with keys ('stimOff_times', 'errorCue_times', 'intervals') """ - metric = np.nan_to_num(data["stimOff_times"] - data["errorCue_times"] - 2, nan=np.inf) + metric = np.nan_to_num(data['stimOff_times'] - data['errorCue_times'] - 2, nan=np.inf) # Apply criteria passed = (np.abs(metric) < 0.15).astype(float) # Remove none negative feedback trials - metric[data["correct"]] = passed[data["correct"]] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[data['correct']] = passed[data['correct']] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -515,12 +515,12 @@ def check_wheel_move_before_feedback(data, **_): """ # Get tuple of wheel times and positions within 100ms of feedback traces = traces_by_trial( - data["wheel_timestamps"], - data["wheel_position"], - start=data["feedback_times"] - 0.05, - end=data["feedback_times"] + 0.05, + data['wheel_timestamps'], + data['wheel_position'], + start=data['feedback_times'] - 0.05, + end=data['feedback_times'] + 0.05, ) - metric = np.zeros_like(data["feedback_times"]) + metric = np.zeros_like(data['feedback_times']) # For each trial find the displacement for i, trial in enumerate(traces): pos = trial[1] @@ -528,12 +528,12 @@ def check_wheel_move_before_feedback(data, **_): metric[i] = pos[-1] - pos[0] # except no-go trials - metric[data["choice"] == 0] = np.nan # NaN = trial ignored for this check + metric[data['choice'] == 0] = np.nan # NaN = trial ignored for this check nans = np.isnan(metric) passed = np.zeros_like(metric) * np.nan passed[~nans] = (metric[~nans] != 0).astype(float) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -555,15 +555,15 @@ def _wheel_move_during_closed_loop(re_ts, re_pos, data, wheel_gain=None, tol=1, :param tol: the criterion in visual degrees """ if wheel_gain is None: - _log.warning("No wheel_gain input in function call, returning None") + _log.warning('No wheel_gain input in function call, returning None') return None, None # Get tuple of wheel times and positions over each trial's closed-loop period traces = traces_by_trial(re_ts, re_pos, - start=data["goCueTrigger_times"], - end=data["response_times"]) + start=data['goCueTrigger_times'], + end=data['response_times']) - metric = np.zeros_like(data["feedback_times"]) + metric = np.zeros_like(data['feedback_times']) # For each trial find the absolute displacement for i, trial in enumerate(traces): t, pos = trial @@ -574,16 +574,16 @@ def _wheel_move_during_closed_loop(re_ts, re_pos, data, wheel_gain=None, tol=1, metric[i] = np.abs(pos - origin).max() # Load wheel_gain and thresholds for each trial - wheel_gain = np.array([wheel_gain] * len(data["position"])) - thresh = data["position"] + wheel_gain = np.array([wheel_gain] * len(data['position'])) + thresh = data['position'] # abs displacement, s, in mm required to move 35 visual degrees s_mm = np.abs(thresh / wheel_gain) # don't care about direction criterion = cm_to_rad(s_mm * 1e-1) # convert abs displacement to radians (wheel pos is in rad) metric = metric - criterion # difference should be close to 0 rad_per_deg = cm_to_rad(1 / wheel_gain * 1e-1) passed = (np.abs(metric) < rad_per_deg * tol).astype(float) # less than 1 visual degree off - metric[data["choice"] == 0] = passed[data["choice"] == 0] = np.nan # except no-go trials - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[data['choice'] == 0] = passed[data['choice'] == 0] = np.nan # except no-go trials + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -642,25 +642,25 @@ def check_wheel_freeze_during_quiescence(data, **_): :param data: dict of trial data with keys ('wheel_timestamps', 'wheel_position', 'quiescence', 'intervals', 'stimOnTrigger_times') """ - assert np.all(np.diff(data["wheel_timestamps"]) >= 0) - assert data["quiescence"].size == data["stimOnTrigger_times"].size + assert np.all(np.diff(data['wheel_timestamps']) >= 0) + assert data['quiescence'].size == data['stimOnTrigger_times'].size # Get tuple of wheel times and positions over each trial's quiescence period - qevt_start_times = data["stimOnTrigger_times"] - data["quiescence"] + qevt_start_times = data['stimOnTrigger_times'] - data['quiescence'] traces = traces_by_trial( - data["wheel_timestamps"], - data["wheel_position"], + data['wheel_timestamps'], + data['wheel_position'], start=qevt_start_times, - end=data["stimOnTrigger_times"] + end=data['stimOnTrigger_times'] ) - metric = np.zeros((len(data["quiescence"]), 2)) # (n_trials, n_directions) + metric = np.zeros((len(data['quiescence']), 2)) # (n_trials, n_directions) for i, trial in enumerate(traces): t, pos = trial # Get the last position before the period began if pos.size > 0: # Find the position of the preceding sample and subtract it - idx = np.abs(data["wheel_timestamps"] - t[0]).argmin() - 1 - origin = data["wheel_position"][idx if idx != -1 else 0] + idx = np.abs(data['wheel_timestamps'] - t[0]).argmin() - 1 + origin = data['wheel_position'][idx if idx != -1 else 0] # Find the absolute min and max relative to the last sample metric[i, :] = np.abs([np.min(pos - origin), np.max(pos - origin)]) # Reduce to the largest displacement found in any direction @@ -668,7 +668,7 @@ def check_wheel_freeze_during_quiescence(data, **_): metric = 180 * metric / np.pi # convert to degrees from radians criterion = 2 # Position shouldn't change more than 2 in either direction passed = metric < criterion - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -685,8 +685,8 @@ def check_detected_wheel_moves(data, min_qt=0, **_): """ # Depending on task version this may be a single value or an array of quiescent periods min_qt = np.array(min_qt) - if min_qt.size > data["intervals"].shape[0]: - min_qt = min_qt[:data["intervals"].shape[0]] + if min_qt.size > data['intervals'].shape[0]: + min_qt = min_qt[:data['intervals'].shape[0]] metric = data['firstMovement_times'] qevt_start = data['goCueTrigger_times'] - np.array(min_qt) @@ -714,25 +714,25 @@ def check_error_trial_event_sequence(data, **_): """ # An array the length of N trials where True means at least one event time was NaN (bad) nans = ( - np.isnan(data["intervals"][:, 0]) | - np.isnan(data["goCue_times"]) | # noqa - np.isnan(data["errorCue_times"]) | # noqa - np.isnan(data["itiIn_times"]) | # noqa - np.isnan(data["intervals"][:, 1]) + np.isnan(data['intervals'][:, 0]) | + np.isnan(data['goCue_times']) | # noqa + np.isnan(data['errorCue_times']) | # noqa + np.isnan(data['itiIn_times']) | # noqa + np.isnan(data['intervals'][:, 1]) ) # For each trial check that the events happened in the correct order (ignore NaN values) - a = np.less(data["intervals"][:, 0], data["goCue_times"], where=~nans) # Start time < go cue - b = np.less(data["goCue_times"], data["errorCue_times"], where=~nans) # Go cue < error cue - c = np.less(data["errorCue_times"], data["itiIn_times"], where=~nans) # Error cue < ITI start - d = np.less(data["itiIn_times"], data["intervals"][:, 1], where=~nans) # ITI start < end time + a = np.less(data['intervals'][:, 0], data['goCue_times'], where=~nans) # Start time < go cue + b = np.less(data['goCue_times'], data['errorCue_times'], where=~nans) # Go cue < error cue + c = np.less(data['errorCue_times'], data['itiIn_times'], where=~nans) # Error cue < ITI start + d = np.less(data['itiIn_times'], data['intervals'][:, 1], where=~nans) # ITI start < end time # For each trial check all events were in order AND all event times were not NaN metric = a & b & c & d & ~nans passed = metric.astype(float) - passed[data["correct"]] = np.nan # Look only at incorrect trials - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['correct']] = np.nan # Look only at incorrect trials + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -749,25 +749,25 @@ def check_correct_trial_event_sequence(data, **_): """ # An array the length of N trials where True means at least one event time was NaN (bad) nans = ( - np.isnan(data["intervals"][:, 0]) | - np.isnan(data["goCue_times"]) | # noqa - np.isnan(data["valveOpen_times"]) | - np.isnan(data["itiIn_times"]) | # noqa - np.isnan(data["intervals"][:, 1]) + np.isnan(data['intervals'][:, 0]) | + np.isnan(data['goCue_times']) | # noqa + np.isnan(data['valveOpen_times']) | + np.isnan(data['itiIn_times']) | # noqa + np.isnan(data['intervals'][:, 1]) ) # For each trial check that the events happened in the correct order (ignore NaN values) - a = np.less(data["intervals"][:, 0], data["goCue_times"], where=~nans) # Start time < go cue - b = np.less(data["goCue_times"], data["valveOpen_times"], where=~nans) # Go cue < feedback - c = np.less(data["valveOpen_times"], data["itiIn_times"], where=~nans) # Feedback < ITI start - d = np.less(data["itiIn_times"], data["intervals"][:, 1], where=~nans) # ITI start < end time + a = np.less(data['intervals'][:, 0], data['goCue_times'], where=~nans) # Start time < go cue + b = np.less(data['goCue_times'], data['valveOpen_times'], where=~nans) # Go cue < feedback + c = np.less(data['valveOpen_times'], data['itiIn_times'], where=~nans) # Feedback < ITI start + d = np.less(data['itiIn_times'], data['intervals'][:, 1], where=~nans) # ITI start < end time # For each trial True means all events were in order AND all event times were not NaN metric = a & b & c & d & ~nans passed = metric.astype(float) - passed[~data["correct"]] = np.nan # Look only at correct trials - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[~data['correct']] = np.nan # Look only at correct trials + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -799,7 +799,7 @@ def check_n_trial_events(data, **_): 'wheel_moves_peak_amplitude', 'wheel_moves_intervals', 'wheel_timestamps', 'wheel_intervals', 'stimFreeze_times'] events = [k for k in data.keys() if k.endswith('_times') and k not in exclude] - metric = np.zeros(data["intervals"].shape[0], dtype=bool) + metric = np.zeros(data['intervals'].shape[0], dtype=bool) # For each trial interval check that one of each trial event occurred. For incorrect trials, # check the error cue trigger occurred within the interval, otherwise check it is nan. @@ -822,9 +822,9 @@ def check_trial_length(data, **_): :param data: dict of trial data with keys ('feedback_times', 'goCue_times', 'intervals') """ # NaN values are usually ignored so replace them with Inf so they fail the threshold - metric = np.nan_to_num(data["feedback_times"] - data["goCue_times"], nan=np.inf) + metric = np.nan_to_num(data['feedback_times'] - data['goCue_times'], nan=np.inf) passed = (metric < 60.1) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -835,14 +835,14 @@ def check_goCue_delays(data, **_): effectively played is smaller than 1ms. Metric: M = goCue_times - goCueTrigger_times - Criterion: 0 < M <= 0.001 s + Criterion: 0 < M <= 0.0015 s Units: seconds [s] :param data: dict of trial data with keys ('goCue_times', 'goCueTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["goCue_times"] - data["goCueTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['goCue_times'] - data['goCueTrigger_times'], nan=np.inf) passed = (metric <= 0.0015) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -850,16 +850,16 @@ def check_errorCue_delays(data, **_): """ Check that the time difference between the error sound being triggered and effectively played is smaller than 1ms. Metric: M = errorCue_times - errorCueTrigger_times - Criterion: 0 < M <= 0.001 s + Criterion: 0 < M <= 0.0015 s Units: seconds [s] :param data: dict of trial data with keys ('errorCue_times', 'errorCueTrigger_times', 'intervals', 'correct') """ - metric = np.nan_to_num(data["errorCue_times"] - data["errorCueTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['errorCue_times'] - data['errorCueTrigger_times'], nan=np.inf) passed = ((metric <= 0.0015) & (metric > 0)).astype(float) - passed[data["correct"]] = metric[data["correct"]] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['correct']] = metric[data['correct']] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -868,15 +868,15 @@ def check_stimOn_delays(data, **_): and the stimulus effectively appearing on the screen is smaller than 150 ms. Metric: M = stimOn_times - stimOnTrigger_times - Criterion: 0 < M < 0.150 s + Criterion: 0 < M < 0.15 s Units: seconds [s] :param data: dict of trial data with keys ('stimOn_times', 'stimOnTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["stimOn_times"] - data["stimOnTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['stimOn_times'] - data['stimOnTrigger_times'], nan=np.inf) passed = (metric <= 0.15) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -886,15 +886,15 @@ def check_stimOff_delays(data, **_): is smaller than 150 ms. Metric: M = stimOff_times - stimOffTrigger_times - Criterion: 0 < M < 0.150 s + Criterion: 0 < M < 0.15 s Units: seconds [s] :param data: dict of trial data with keys ('stimOff_times', 'stimOffTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["stimOff_times"] - data["stimOffTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['stimOff_times'] - data['stimOffTrigger_times'], nan=np.inf) passed = (metric <= 0.15) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -904,15 +904,15 @@ def check_stimFreeze_delays(data, **_): is smaller than 150 ms. Metric: M = stimFreeze_times - stimFreezeTrigger_times - Criterion: 0 < M < 0.150 s + Criterion: 0 < M < 0.15 s Units: seconds [s] :param data: dict of trial data with keys ('stimFreeze_times', 'stimFreezeTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["stimFreeze_times"] - data["stimFreezeTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['stimFreeze_times'] - data['stimFreezeTrigger_times'], nan=np.inf) passed = (metric <= 0.15) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -934,7 +934,7 @@ def check_reward_volumes(data, **_): passed[correct] = (1.5 <= metric[correct]) & (metric[correct] <= 3.) # Check incorrect trials are 0 passed[~correct] = metric[~correct] == 0 - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -946,7 +946,7 @@ def check_reward_volume_set(data, **_): :param data: dict of trial data with keys ('rewardVolume') """ - metric = data["rewardVolume"] + metric = data['rewardVolume'] passed = 0 < len(set(metric)) <= 2 and 0. in metric return metric, passed @@ -994,19 +994,19 @@ def check_stimulus_move_before_goCue(data, photodiode=None, **_): :param photodiode: the fronts from Bpod's BNC1 input or FPGA frame2ttl channel """ if photodiode is None: - _log.warning("No photodiode TTL input in function call, returning None") + _log.warning('No photodiode TTL input in function call, returning None') return None photodiode_clean = ephys_fpga._clean_frame2ttl(photodiode) - s = photodiode_clean["times"] + s = photodiode_clean['times'] s = s[~np.isnan(s)] # Remove NaNs metric = np.array([]) - for i, c in zip(data["intervals"][:, 0], data["goCue_times"]): + for i, c in zip(data['intervals'][:, 0], data['goCue_times']): metric = np.append(metric, np.count_nonzero(s[s > i] < (c - 0.02))) passed = (metric == 0).astype(float) # Remove no go trials - passed[data["choice"] == 0] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['choice'] == 0] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -1022,12 +1022,12 @@ def check_audio_pre_trial(data, audio=None, **_): :param audio: the fronts from Bpod's BNC2 input FPGA audio sync channel """ if audio is None: - _log.warning("No BNC2 input in function call, retuning None") + _log.warning('No BNC2 input in function call, retuning None') return None - s = audio["times"][~np.isnan(audio["times"])] # Audio TTLs with NaNs removed + s = audio['times'][~np.isnan(audio['times'])] # Audio TTLs with NaNs removed metric = np.array([], dtype=np.int8) - for i, c in zip(data["intervals"][:, 0], data["goCue_times"]): + for i, c in zip(data['intervals'][:, 0], data['goCue_times']): metric = np.append(metric, sum(s[s > i] < (c - 0.02))) passed = metric == 0 - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed From b567a3c5fc059bfed57d3464856ab4a442d98814 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 10 Oct 2023 12:03:22 +0100 Subject: [PATCH 17/68] remove behavior flag, infer from files --- ibllib/io/extractors/camera.py | 3 +-- ibllib/io/extractors/video_motion.py | 11 +++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index 4bcb0699c..e287fe58c 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -153,8 +153,7 @@ def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio', _logger.warning('Attempting to align using wheel') try: - motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, behavior=False, - upload=True) + motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, upload=True) new_times = motion_class.process() if not motion_class.qc_outcome: raise ValueError(f'Wheel alignment failed to pass qc: {motion_class.qc}') diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index af188c0d8..8c62bfcc8 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -29,6 +29,7 @@ from brainbox.behavior.dlc import likelihood_threshold, get_speed from brainbox.task.trials import find_trial_ids import one.alf.io as alfio +from one.alf.exceptions import ALFObjectNotFound from one.alf.spec import is_session_path, is_uuid_string @@ -400,19 +401,18 @@ def __init__(self, session_path, label, **kwargs): self.session_path = session_path self.label = label self.threshold = kwargs.get('threshold', 20) - self.behavior = kwargs.get('behavior', False) self.upload = kwargs.get('upload', False) self.twin = kwargs.get('twin', 150) self.nprocess = kwargs.get('nprocess', int(cpu_count() - cpu_count() / 4)) - self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None), behavior=self.behavior) + self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None)) self.roi, self.mask = self.get_roi_mask() if self.upload: self.one = ONE(mode='remote') self.eid = self.one.path2eid(self.session_path) - def load_data(self, sync='nidq', location=None, behavior=False): + def load_data(self, sync='nidq', location=None): def fix_keys(alf_object): ob = Bunch() for key in alf_object.keys(): @@ -454,10 +454,13 @@ def fix_keys(alf_object): self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) - if behavior: + try: self.trials = alfio.load_file_content(next(alf_path.glob('_ibl_trials.table.*.pqt'))) self.dlc = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.dlc.*.pqt'))) self.dlc = likelihood_threshold(self.dlc) + self.behavior = True + except ALFObjectNotFound: + self.behavior = False self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0]) From 69972c3393a8d21798b479da455afa2839adf87a Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 10 Oct 2023 12:27:12 +0100 Subject: [PATCH 18/68] restrict to left and right cameras --- ibllib/io/extractors/camera.py | 4 ++++ ibllib/io/extractors/video_motion.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index e287fe58c..bf3c95528 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -153,6 +153,10 @@ def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio', _logger.warning('Attempting to align using wheel') try: + if self.label not in ['left', 'right']: + # Can only use wheel alignment for left and right cameras + raise ValueError(f'Wheel alignment not supported for {self.label} camera') + motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, upload=True) new_times = motion_class.process() if not motion_class.qc_outcome: diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index 8c62bfcc8..ec085d2c4 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -459,7 +459,7 @@ def fix_keys(alf_object): self.dlc = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.dlc.*.pqt'))) self.dlc = likelihood_threshold(self.dlc) self.behavior = True - except ALFObjectNotFound: + except (ALFObjectNotFound, StopIteration): self.behavior = False self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0]) From 659efcfa659fe991289936aa008e80a76e51fd3f Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 10 Oct 2023 12:33:05 +0100 Subject: [PATCH 19/68] fix glob pattern --- ibllib/io/extractors/video_motion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index ec085d2c4..7cb3bccbb 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -455,8 +455,8 @@ def fix_keys(alf_object): self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) try: - self.trials = alfio.load_file_content(next(alf_path.glob('_ibl_trials.table.*.pqt'))) - self.dlc = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.dlc.*.pqt'))) + self.trials = alfio.load_file_content(next(alf_path.glob('_ibl_trials.table*.pqt'))) + self.dlc = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.dlc*.pqt'))) self.dlc = likelihood_threshold(self.dlc) self.behavior = True except (ALFObjectNotFound, StopIteration): From 2a66a1392558441007ce1b7a1c55cdc998243e24 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 10 Oct 2023 13:55:13 +0100 Subject: [PATCH 20/68] authenticate alyx --- ibllib/io/extractors/video_motion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index 7cb3bccbb..981af8a18 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -410,6 +410,7 @@ def __init__(self, session_path, label, **kwargs): if self.upload: self.one = ONE(mode='remote') + self.one.alyx.authenticate() self.eid = self.one.path2eid(self.session_path) def load_data(self, sync='nidq', location=None): From 04733ac63f253a2b5e3bd28a6184e00e3966e693 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 10 Oct 2023 16:17:49 +0100 Subject: [PATCH 21/68] save figure with label name --- ibllib/io/extractors/video_motion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index 981af8a18..4d567b2d0 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -903,7 +903,8 @@ def process(self): if self.upload: fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior() - save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', 'video_wheel_alignment.png')) + save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', + f'video_wheel_alignment_{self.label}.png')) save_fig_path.parent.mkdir(exist_ok=True, parents=True) fig.savefig(save_fig_path) snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one) From 10007ec8350cb378413123ae2ebda70688c5640b Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 11 Oct 2023 10:20:53 +0100 Subject: [PATCH 22/68] close figure and change processes --- ibllib/io/extractors/video_motion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index 4d567b2d0..889dc0b19 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -881,7 +881,7 @@ def process(self): wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), int(self.camera_meta['fps'] * toverlap)) - out = Parallel(n_jobs=self.nprocess)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) + out = Parallel(n_jobs=4)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) for iw, (first, last) in enumerate(wg.firstlast)) self.shifts = np.array([]) @@ -910,5 +910,6 @@ def process(self): snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one) snp.outputs = [save_fig_path] snp.register_images(widths=['orig']) + plt.close(fig) return self.new_times From c8d6556db9832f336178b682182ef9691fbd4b18 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 11 Oct 2023 12:24:03 +0100 Subject: [PATCH 23/68] doc strings to wheel alignment module --- ibllib/io/extractors/video_motion.py | 182 ++++++++++++++++++++++++++- 1 file changed, 177 insertions(+), 5 deletions(-) diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index 889dc0b19..14de3b3f7 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -398,6 +398,17 @@ def process_key(event): class MotionAlignmentFullSession: def __init__(self, session_path, label, **kwargs): + """ + Class to extract camera times using video motion energy wheel alignment + :param session_path: path of the session + :param label: video label, only 'left' and 'right' videos are supported + :param kwargs: threshold - the threshold to apply when identifying frames with artefacts (default 20) + upload - whether to upload summary figure to alyx (default False) + twin - the window length used when computing the shifts between the wheel and video + nprocesses - the number of CPU processes to use + sync - the type of sync scheme used (options 'nidq' or 'bpod') + location - whether the code is being run on SDSC or not (options 'SDSC' or None) + """ self.session_path = session_path self.label = label self.threshold = kwargs.get('threshold', 20) @@ -414,7 +425,19 @@ def __init__(self, session_path, label, **kwargs): self.eid = self.one.path2eid(self.session_path) def load_data(self, sync='nidq', location=None): + """ + Loads relevant data from disk to perform motion alignment + :param sync: type of sync used, 'nidq' or 'bpod' + :param location: where the code is being run, if location='SDSC', the dataset uuids are removed + when loading the data + :return: + """ def fix_keys(alf_object): + """ + Given an alf object removes the dataset uuid from the keys + :param alf_object: + :return: + """ ob = Bunch() for key in alf_object.keys(): vals = alf_object[key] @@ -422,39 +445,54 @@ def fix_keys(alf_object): return ob alf_path = self.session_path.joinpath('alf') + # Load in wheel data wheel = (fix_keys(alfio.load_object(alf_path, 'wheel')) if location == 'SDSC' else alfio.load_object(alf_path, 'wheel')) self.wheel_timestamps = wheel.timestamps + # Compute interpolated wheel position and wheel times wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000) + # Compute wheel velocity self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000) + # Load in original camera times self.camera_times = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.times*.npy'))) + # Find raw video file and load in the metadata self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob( f'_iblrig_{self.label}Camera.raw*.mp4'))) self.camera_meta = vidio.get_video_meta(self.camera_path) # TODO should read in the description file to get the correct sync location if sync == 'nidq': + # If the sync is 'nidq' we read in the camera ttls from the spikeglx sync object sync, chmap = get_sync_and_chn_map(self.session_path, sync_collection='raw_ephys_data') sr = get_sync_fronts(sync, chmap[f'{self.label}_camera']) self.ttls = sr.times[::2] else: + # Otherwise we assume the sync is 'bpod' and we read in the camera ttls from the raw bpod data cam_extractor = cam.CameraTimestampsBpod(session_path=self.session_path) cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data') self.ttls = cam_extractor._times_from_bpod() + # Check if the ttl and video sizes match up self.tdiff = self.ttls.size - self.camera_meta['length'] if self.tdiff < 0: + # In this case there are fewer ttls than camera frames. This is not ideal, for now we pad the ttls with + # nans but if this is too many we reject the wheel alignment based on the qc self.ttl_times = self.ttls self.times = np.r_[self.ttl_times, np.full((np.abs(self.tdiff)), np.nan)] self.short_flag = True elif self.tdiff > 0: + # In this case there are more ttls than camera frames. This happens often, for now we remove the first + # tdiff ttls from the ttls self.ttl_times = self.ttls[self.tdiff:] self.times = self.ttls[self.tdiff:] self.short_flag = False + # Compute the frame rate of the camera self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) + # We attempt to load in some behavior data (trials and dlc). This is only needed for the summary plots, having + # trial aligned paw velocity (from the dlc) is a nice sanity check to make sure the alignment went well try: self.trials = alfio.load_file_content(next(alf_path.glob('_ibl_trials.table*.pqt'))) self.dlc = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.dlc*.pqt'))) @@ -463,9 +501,15 @@ def fix_keys(alf_object): except (ALFObjectNotFound, StopIteration): self.behavior = False + # Load in a single frame that we will use for the summary plot self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0]) def get_roi_mask(self): + """ + Compute the region of interest mask for a given camera. This corresponds to a box in the video that we will + use to compute the wheel motion energy + :return: + """ if self.label == 'right': roi = ((450, 512), (120, 200)) @@ -476,27 +520,53 @@ def get_roi_mask(self): return roi, roi_mask def find_contaminated_frames(self, video_frames, thresold=20, normalise=True): + """ + Finds frames in the video that have artefacts such as the mouse's paw or a human hand. In order to determine + frames with contamination an Otsu thresholding is applied to each frame to detect the artefact from the + background image + :param video_frames: np array of video frames (nframes, nwidth, nheight) + :param thresold: threshold to differentiate artefact from background + :param normalise: whether to normalise the threshold values for each frame to the baseline + :return: mask of frames that are contaminated + """ high = np.zeros((video_frames.shape[0])) + # Iterate through each frame and compute and store the otsu threshold value for each frame for idx, frame in enumerate(video_frames): ret, _ = cv2.threshold(cv2.GaussianBlur(frame, (5, 5), 0), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) high[idx] = ret + # If normalise is True, we divide the threshold values for each frame by the minimum value if normalise: high -= np.min(high) + # Identify the frames that have a threshold value greater than the specified threshold cutoff contaminated_frames = np.where(high > thresold)[0] return contaminated_frames def compute_motion_energy(self, first, last, wg, iw): + """ + Computes the video motion energy for frame indexes between first and last. This function is written to be run + in a parallel fashion jusing joblib.parallel + :param first: first frame index of frame interval to consider + :param last: last frame index of frame interval to consider + :param wg: WindowGenerator + :param iw: iteration of the WindowGenerator + :return: + """ if iw == wg.nwin - 1: return + # Open the video and read in the relvant video frames between first idx and last idx cap = cv2.VideoCapture(self.camera_path) frames = vidio.get_video_frames_preload(cap, np.arange(first, last), mask=self.mask) + # Identify if any of the frames have artefacts in them idx = self.find_contaminated_frames(frames, self.threshold) + # If some of the frames are contaminated we find all the continuous intervals of contamination + # and set the value for contaminated pixels for these frames to the average of the first frame before and after + # this contamination interval if len(idx) != 0: before_status = False @@ -504,6 +574,9 @@ def compute_motion_energy(self, first, last, wg, iw): counter = 0 n_frames = 200 + # If it is the first frame that is contaminated, we need to read in a bit more of the video to find a + # frame prior to contamination. We attempt this 20 times, after that we just take the value for the first + # frame while np.any(idx == 0) and counter < 20 and iw != 0: n_before_offset = (counter + 1) * n_frames first -= n_frames @@ -518,6 +591,9 @@ def compute_motion_energy(self, first, last, wg, iw): print(f'In before: {counter}') counter = 0 + # If it is the last frame that is contaminated, we need to read in a bit more of the video to find a + # frame after the contamination. We attempt this 20 times, after that we just take the value for the last + # frame while np.any(idx == frames.shape[0] - 1) and counter < 20 and iw != wg.nwin - 1: n_after_offset = (counter + 1) * n_frames last += n_frames @@ -531,6 +607,8 @@ def compute_motion_energy(self, first, last, wg, iw): if counter > 0: print(f'In after: {counter}') + # We find all the continuous intervals that contain contamination and fix the affected pixels + # by taking the average value of the frame prior and after contamination intervals = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1) for ints in intervals: if len(ints) > 0 and ints[0] == 0: @@ -538,22 +616,28 @@ def compute_motion_energy(self, first, last, wg, iw): if len(ints) > 0 and ints[-1] == frames.shape[0] - 1: ints = ints[:-1] th_all = np.zeros_like(frames[0]) + # We find all affected pixels for idx in ints: img = np.copy(frames[idx]) blur = cv2.GaussianBlur(img, (5, 5), 0) ret, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) th = cv2.GaussianBlur(th, (5, 5), 10) th_all += th + # Compute the average image of the frame prior and after the interval vals = np.mean(np.dstack([frames[ints[0] - 1], frames[ints[-1] + 1]]), axis=-1) + # For each frame set the affected pixels to the value of the clean average image for idx in ints: img = frames[idx] img[th_all > 0] = vals[th_all > 0] + # If we have read in extra video frames we need to cut these off and make sure we only + # consider the frames between the interval first and last given as args if before_status: frames = frames[n_before_offset:] if after_status: frames = frames[:(-1 * n_after_offset)] + # Once the frames have been cleaned we compute the motion energy between frames frame_me, _ = video.motion_energy(frames, diff=2, normalize=False) cap.release() @@ -561,32 +645,54 @@ def compute_motion_energy(self, first, last, wg, iw): return frame_me[2:] def compute_shifts(self, times, me, first, last, iw, wg): + """ + Compute the cross-correlation between the video motion energy and the wheel velocity to find the mismatch + between the camera ttls and the video frames. This function is written to run in a parallel manner using + joblib.parallel + + :param times: the times of the video frames across the whole session (ttls) + :param me: the video motion energy computed across the whole session + :param first: first time idx to consider + :param last: last time idx to consider + :param wg: WindowGenerator + :param iw: iteration of the WindowGenerator + :return: + """ + # If we are in the last window we exit if iw == wg.nwin - 1: return np.nan, np.nan + + # Find the time interval we are interested in t_first = times[first] t_last = times[last] + + # If both times during this interval are nan exit if np.isnan(t_last) and np.isnan(t_first): return np.nan, np.nan + # If only the last time is nan, we find the last non nan time value elif np.isnan(t_last): t_last = times[np.where(~np.isnan(times))[0][-1]] + # Find the mask of timepoints that fall in this interval mask = np.logical_and(times >= t_first, times <= t_last) + # Restrict the video motion energy to this interval and normalise the values align_me = me[np.where(mask)[0]] align_me = (align_me - np.nanmin(align_me)) / (np.nanmax(align_me) - np.nanmin(align_me)) - # Find closest timepoints in wheel that match the camera times + # Find closest timepoints in wheel that match the time interval wh_mask = np.logical_and(self.wheel_time >= t_first, self.wheel_time <= t_last) if np.sum(wh_mask) == 0: return np.nan, np.nan + # Find the mask for the wheel times xs = np.searchsorted(self.wheel_time[wh_mask], times[mask]) xs[xs == np.sum(wh_mask)] = np.sum(wh_mask) - 1 # Convert to normalized speed vs = np.abs(self.wheel_vel[wh_mask][xs]) vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) + # Account for nan values in the video motion energy isnan = np.isnan(align_me) - if np.sum(isnan) > 0: where_nan = np.where(isnan)[0] assert where_nan[0] == 0 @@ -595,12 +701,22 @@ def compute_shifts(self, times, me, first, last, iw, wg): if np.all(isnan): return np.nan, np.nan + # Compute the cross correlation between the video motion energy and the wheel speed xcorr = signal.correlate(align_me[~isnan], vs[~isnan]) + # The max value of the cross correlation indicates the shift that needs to be applied + # The +2 comes from the fact that the video motion energy was computed from the difference between frames shift = np.nanargmax(xcorr) - align_me[~isnan].size + 2 return shift, t_first + (t_last - t_first) / 2 def clean_shifts(self, x, n=1): + """ + Removes artefacts from the computed shifts across time. We assume that the shifts should never increase + over time and that the jump between consecutive shifts shouldn't be greater than 1 + :param x: computed shifts + :param n: condition to apply + :return: + """ y = x.copy() dy = np.diff(y, prepend=y[0]) while True: @@ -625,6 +741,17 @@ def clean_shifts(self, x, n=1): return np.cumsum(dy) + y[0] def qc_shifts(self, shifts, shifts_filt): + """ + Compute qc values for the wheel alignment. We consider 4 things + 1. The number of camera ttl values that are missing (when we have less ttls than video frames) + 2. The number of shifts that have nan values, this means the video motion energy computation + 3. The number of large jumps (>10) between the computed shifts + 4. The number of jumps (>1) between the shifts after they have been cleaned + + :param shifts: np.array of shifts over session + :param shifts_filt: np.array of shifts after being cleaned over session + :return: + """ ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100 if self.tdiff < 0 else 0 nan_per = (np.sum(np.isnan(shifts_filt)) / shifts_filt.size) * 100 @@ -654,11 +781,21 @@ def qc_shifts(self, shifts, shifts_filt): return qc, qc_outcome def extract_times(self, shifts_filt, t_shifts): + """ + Extracts new camera times after applying the computed shifts across the session + :param shifts_filt: filtered shifts computed across session + :param t_shifts: time point of computed shifts + :return: + """ + + # Compute the interpolation function to apply to the ttl times t_new = t_shifts - (shifts_filt * 1 / self.frate) fcn = interpolate.interp1d(t_shifts, t_new, fill_value="extrapolate") + # Apply the function and get out new times new_times = fcn(self.ttl_times) + # If we are missing ttls then interpolate and append the correct number at the end if self.tdiff < 0: to_app = (np.arange(np.abs(self.tdiff), ) + 1) / self.frate + new_times[-1] new_times = np.r_[new_times, to_app] @@ -667,8 +804,21 @@ def extract_times(self, shifts_filt, t_shifts): @staticmethod def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True, - norm=False, - axs=None): + norm=False, axs=None): + """ + Compute and plot trial aligned spike rasters and psth + :param spike_times: times of variable + :param events: trial times to align to + :param trial_idx: trial idx to sort by + :param dividers: + :param colors: + :param labels: + :param weights: + :param fr: + :param norm: + :param axs: + :return: + """ pre_time = 0.4 post_time = 1 raster_bin = 0.01 @@ -737,6 +887,10 @@ def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labe return fig, axs def plot_with_behavior(self): + """ + Makes a summary figure of the alignment when behaviour data is available + :return: + """ self.dlc = likelihood_threshold(self.dlc) trial_idx, dividers = find_trial_ids(self.trials, sort='side') @@ -817,6 +971,10 @@ def plot_with_behavior(self): return fig def plot_without_behavior(self): + """ + Makes a summary figure of the alignment when behaviour data is not available + :return: + """ fig = plt.figure() fig.set_size_inches(7, 7) @@ -863,6 +1021,20 @@ def plot_without_behavior(self): return fig def process(self): + """ + Main function used to apply the video motion wheel alignment to the camera times. This function does the + following + 1. Computes the video motion energy across the whole session (computed in windows and parallelised) + 2. Computes the shift that should be applied to the camera times across the whole session by computing + the cross correlation between the video motion energy and the wheel speed (computed in + overlapping windows and parallelised) + 3. Removes artefacts from the computed shifts + 4. Computes the qc for the wheel alignment + 5. Extracts the new camera times using the shifts computed from the video wheel alignment + 6. If upload is True, creates a summary plot of the alignment and uploads the figure to the relevant session + on alyx + :return: + """ # Compute the motion energy of the wheel for the whole video wg = WindowGenerator(self.camera_meta['length'], 5000, 4) @@ -881,7 +1053,7 @@ def process(self): wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), int(self.camera_meta['fps'] * toverlap)) - out = Parallel(n_jobs=4)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) + out = Parallel(n_jobs=self.nprocess)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) for iw, (first, last) in enumerate(wg.firstlast)) self.shifts = np.array([]) From c6b0784db491b3b29c56c46365a8a37bbbd1248d Mon Sep 17 00:00:00 2001 From: k1o0 Date: Wed, 11 Oct 2023 14:44:52 +0300 Subject: [PATCH 24/68] Mesoscope brain location (#656) * Update surgery JSON with normal vector * topLeftDeg -> Deg[topLeft] * Document pipes package * oneibl: partial update of session fields if session exists * Correct keys in meta for update_surgery_json * Rename UUID field * register all meta files * Handle sessions without task data * PostDLC dynamic task * add slices info to db for multi-plane recordings * fix syntax error to get other tasks going * save figure with label name * close figure and change processes * Dry mode works in register_fov without ONE instance --------- Co-authored-by: Samuel Picard Co-authored-by: olivier Co-authored-by: Mayo Faulkner --- ibllib/__init__.py | 5 +- ibllib/io/extractors/mesoscope.py | 20 ++- ibllib/io/extractors/video_motion.py | 97 +++++-------- ibllib/oneibl/data_handlers.py | 17 +-- ibllib/oneibl/registration.py | 121 ++++++++------- ibllib/pipes/__init__.py | 30 +++- ibllib/pipes/dynamic_pipeline.py | 30 +++- ibllib/pipes/ephys_preprocessing.py | 5 + ibllib/pipes/local_server.py | 63 ++++++-- ibllib/pipes/mesoscope_tasks.py | 109 ++++++++++---- ibllib/pipes/misc.py | 1 + ibllib/pipes/purge_rig_data.py | 11 +- ibllib/pipes/tasks.py | 50 +++++-- ibllib/pipes/training_preprocessing.py | 6 + ibllib/pipes/video_tasks.py | 194 ++++++++++++++++++++++--- ibllib/plots/figures.py | 42 +++--- ibllib/tests/test_mesoscope.py | 62 ++++++-- 17 files changed, 606 insertions(+), 257 deletions(-) diff --git a/ibllib/__init__.py b/ibllib/__init__.py index 4e328ee70..55e4b2b5d 100644 --- a/ibllib/__init__.py +++ b/ibllib/__init__.py @@ -2,13 +2,12 @@ import logging import warnings -__version__ = '2.26' +__version__ = '2.27' warnings.filterwarnings('always', category=DeprecationWarning, module='ibllib') # if this becomes a full-blown library we should let the logging configuration to the discretion of the dev # who uses the library. However since it can also be provided as an app, the end-users should be provided -# with an useful default logging in standard output without messing with the complex python logging system -# -*- coding:utf-8 -*- +# with a useful default logging in standard output without messing with the complex python logging system USE_LOGGING = True #%(asctime)s,%(msecs)d if USE_LOGGING: diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 561bb6343..78ed21674 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -23,24 +23,34 @@ def patch_imaging_meta(meta: dict) -> dict: """ - Patch imaging meta data for compatibility across versions. + Patch imaging metadata for compatibility across versions. A copy of the dict is NOT returned. Parameters ---------- - dict : dict + meta : dict A folder path that contains a rawImagingData.meta file. Returns ------- dict - The loaded meta data file, updated to the most recent version. + The loaded metadata file, updated to the most recent version. """ - # 2023-05-17 (unversioned) adds nFrames and channelSaved keys - if parse_version(meta.get('version') or '0.0.0') <= parse_version('0.0.0'): + # 2023-05-17 (unversioned) adds nFrames, channelSaved keys, MM and Deg keys + version = parse_version(meta.get('version') or '0.0.0') + if version <= parse_version('0.0.0'): if 'channelSaved' not in meta: meta['channelSaved'] = next((x['channelIdx'] for x in meta['FOV'] if 'channelIdx' in x), []) + fields = ('topLeft', 'topRight', 'bottomLeft', 'bottomRight') + for fov in meta.get('FOV', []): + for unit in ('Deg', 'MM'): + if unit not in fov: # topLeftDeg, etc. -> Deg[topLeft] + fov[unit] = {f: fov.pop(f + unit, None) for f in fields} + elif version == parse_version('0.1.0'): + for fov in meta.get('FOV', []): + if 'roiUuid' in fov: + fov['roiUUID'] = fov.pop('roiUuid') return meta diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index 981af8a18..4756b2e3a 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -40,11 +40,7 @@ def find_nearest(array, value): class MotionAlignment: - roi = { - 'left': ((800, 1020), (233, 1096)), - 'right': ((426, 510), (104, 545)), - 'body': ((402, 481), (31, 103)) - } + roi = {'left': ((800, 1020), (233, 1096)), 'right': ((426, 510), (104, 545)), 'body': ((402, 481), (31, 103))} def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), **kwargs): self.one = one or ONE() @@ -94,12 +90,9 @@ def line_select_callback(eclick, erelease): return np.array([[x1, x2], [y1, y2]]) plt.imshow(frame) - roi = RectangleSelector(plt.gca(), line_select_callback, - drawtype='box', useblit=True, - button=[1, 3], # don't use middle button - minspanx=5, minspany=5, - spancoords='pixels', - interactive=True) + roi = RectangleSelector(plt.gca(), line_select_callback, drawtype='box', useblit=True, button=[1, 3], + # don't use middle button + minspanx=5, minspany=5, spancoords='pixels', interactive=True) plt.show() ((x1, x2, *_), (y1, *_, y2)) = roi.corners col = np.arange(round(x1), round(x2), dtype=int) @@ -115,14 +108,13 @@ def load_data(self, download=False): self.data.wheel = self.one.load_object(self.eid, 'wheel') self.data.trials = self.one.load_object(self.eid, 'trials') cam = self.one.load(self.eid, ['camera.times'], dclass_output=True) - self.data.camera_times = {vidio.label_from_path(url): ts - for ts, url in zip(cam.data, cam.url)} + self.data.camera_times = {vidio.label_from_path(url): ts for ts, url in zip(cam.data, cam.url)} else: alf_path = self.session_path / 'alf' self.data.wheel = alfio.load_object(alf_path, 'wheel', short_keys=True) self.data.trials = alfio.load_object(alf_path, 'trials') - self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x) - for x in alf_path.glob('*Camera.times*')} + self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x) for x in + alf_path.glob('*Camera.times*')} assert all(x is not None for x in self.data.values()) def _set_eid_or_path(self, session_path_or_eid): @@ -191,8 +183,7 @@ def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, disp roi = (*[slice(*r) for r in self.roi[side]], 0) try: # TODO Add function arg to make grayscale - self.alignment.frames = \ - vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi) + self.alignment.frames = vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi) assert self.alignment.frames.size != 0 except AssertionError: self.log.error('Failed to open video') @@ -239,8 +230,8 @@ def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, disp y = np.pad(self.alignment.df, 1, 'edge') ax[0].plot(x, y, '-x', label='wheel motion energy') thresh = stDev > sd_thresh - ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1, - linewidth=0.5, linestyle=':', label=f'>{sd_thresh} s.d. diff') + ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1, linewidth=0.5, linestyle=':', + label=f'>{sd_thresh} s.d. diff') ax[1].plot(t[interp_mask], np.abs(v[interp_mask])) # Plot other stuff @@ -307,9 +298,7 @@ def init_plot(): data['frame_num'] = 0 mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0) - data['marker'], = ax.plot( - wheel.timestamps[wheel_mask][mkr], - wheel.position[wheel_mask][mkr], 'r-x') + data['marker'], = ax.plot(wheel.timestamps[wheel_mask][mkr], wheel.position[wheel_mask][mkr], 'r-x') ax.set_ylabel('Wheel position (rad))') ax.set_xlabel('Time (s))') return @@ -338,19 +327,13 @@ def animate(i): data['im'].set_data(frame) mkr = find_nearest(wheel.timestamps[wheel_mask], t_x) - data['marker'].set_data( - wheel.timestamps[wheel_mask][mkr], - wheel.position[wheel_mask][mkr] - ) + data['marker'].set_data(wheel.timestamps[wheel_mask][mkr], wheel.position[wheel_mask][mkr]) return data['im'], data['ln'], data['marker'] anim = animation.FuncAnimation(fig, animate, init_func=init_plot, - frames=(range(len(self.alignment.df)) - if save - else cycle(range(60))), - interval=20, blit=False, - repeat=not save, cache_frame_data=False) + frames=(range(len(self.alignment.df)) if save else cycle(range(60))), interval=20, + blit=False, repeat=not save, cache_frame_data=False) anim.running = False def process_key(event): @@ -422,14 +405,12 @@ def fix_keys(alf_object): return ob alf_path = self.session_path.joinpath('alf') - wheel = (fix_keys(alfio.load_object(alf_path, 'wheel')) if location == 'SDSC' - else alfio.load_object(alf_path, 'wheel')) + wheel = (fix_keys(alfio.load_object(alf_path, 'wheel')) if location == 'SDSC' else alfio.load_object(alf_path, 'wheel')) self.wheel_timestamps = wheel.timestamps wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000) self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000) self.camera_times = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.times*.npy'))) - self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob( - f'_iblrig_{self.label}Camera.raw*.mp4'))) + self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob(f'_iblrig_{self.label}Camera.raw*.mp4'))) self.camera_meta = vidio.get_video_meta(self.camera_path) # TODO should read in the description file to get the correct sync location @@ -521,8 +502,7 @@ def compute_motion_energy(self, first, last, wg, iw): while np.any(idx == frames.shape[0] - 1) and counter < 20 and iw != wg.nwin - 1: n_after_offset = (counter + 1) * n_frames last += n_frames - extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last, last + n_frames), - mask=self.mask) + extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last, last + n_frames), mask=self.mask) frames = np.concatenate([frames, extra_frames], axis=0) idx = self.find_contaminated_frames(frames, self.threshold) after_status = True @@ -666,8 +646,7 @@ def extract_times(self, shifts_filt, t_shifts): return new_times @staticmethod - def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True, - norm=False, + def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True, norm=False, axs=None): pre_time = 0.4 post_time = 1 @@ -687,8 +666,7 @@ def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labe dividers = [0] + dividers + [len(trial_idx)] if axs is None: - fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, - sharex=True) + fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, sharex=True) else: fig = axs[0].get_figure() @@ -707,8 +685,7 @@ def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labe psth_div = np.nanmean(psth[t_ids], axis=0) std_div = np.nanstd(psth[t_ids], axis=0) / np.sqrt(len(t_ids)) - axs[0].fill_between(t_psth, psth_div - std_div, - psth_div + std_div, alpha=0.4, color=colors[lid]) + axs[0].fill_between(t_psth, psth_div - std_div, psth_div + std_div, alpha=0.4, color=colors[lid]) axs[0].plot(t_psth, psth_div, alpha=1, color=colors[lid]) lab_max = idx[np.argmax(t_ints)] @@ -726,8 +703,7 @@ def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labe secax = axs[1].secondary_yaxis('right') secax.set_yticks(label_pos) - secax.set_yticklabels(label, rotation=90, - rotation_mode='anchor', ha='center') + secax.set_yticklabels(label, rotation=90, rotation_mode='anchor', ha='center') for ic, c in enumerate(np.array(colors)[lidx]): secax.get_yticklabels()[ic].set_color(c) @@ -778,8 +754,7 @@ def plot_with_behavior(self): ax02.set_ylabel('Frames') ax02.set_xlabel('Time in session') - ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], - 'k', label='extracted - new') + ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new') ax03.legend() ax03.set_ylim(-5, 5) ax03.set_ylabel('Frames') @@ -792,8 +767,8 @@ def plot_with_behavior(self): ax11.set_title('Wheel') ax12.set_xlabel('Time from first move') - self.single_cluster_raster(self.camera_times, self.trials['firstMovement_times'].values, trial_idx, dividers, - ['g', 'y'], ['left', 'right'], weights=feature_ext, fr=False, axs=[ax21, ax22]) + self.single_cluster_raster(self.camera_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'], + ['left', 'right'], weights=feature_ext, fr=False, axs=[ax21, ax22]) ax21.sharex(ax22) ax21.set_ylabel('Paw r velocity') ax21.set_title('Extracted times') @@ -808,8 +783,7 @@ def plot_with_behavior(self): ax41.imshow(self.frame_example[0]) rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], - self.roi[0][1] - self.roi[0][0], - linewidth=4, edgecolor='g', facecolor='none') + self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none') ax41.add_patch(rect) ax42.plot(self.all_me) @@ -845,8 +819,7 @@ def plot_without_behavior(self): ax02.set_ylabel('Frames') ax02.set_xlabel('Time in session') - ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], - 'k', label='extracted - new') + ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new') ax03.legend() ax03.set_ylim(-5, 5) ax03.set_ylabel('Frames') @@ -854,8 +827,7 @@ def plot_without_behavior(self): ax04.imshow(self.frame_example[0]) rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], - self.roi[0][1] - self.roi[0][0], - linewidth=4, edgecolor='g', facecolor='none') + self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none') ax04.add_patch(rect) ax05.plot(self.all_me) @@ -866,8 +838,8 @@ def process(self): # Compute the motion energy of the wheel for the whole video wg = WindowGenerator(self.camera_meta['length'], 5000, 4) - out = Parallel(n_jobs=self.nprocess)(delayed(self.compute_motion_energy)(first, last, wg, iw) - for iw, (first, last) in enumerate(wg.firstlast)) + out = Parallel(n_jobs=self.nprocess)( + delayed(self.compute_motion_energy)(first, last, wg, iw) for iw, (first, last) in enumerate(wg.firstlast)) # Concatenate the motion energy into one big array self.all_me = np.array([]) for vals in out[:-1]: @@ -878,11 +850,11 @@ def process(self): to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / self.frate)[::-1] times = np.r_[to_app, self.times] - wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), - int(self.camera_meta['fps'] * toverlap)) + wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), int(self.camera_meta['fps'] * toverlap)) - out = Parallel(n_jobs=self.nprocess)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) - for iw, (first, last) in enumerate(wg.firstlast)) + out = Parallel(n_jobs=4)( + delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) for iw, (first, last) in enumerate(wg.firstlast) + ) self.shifts = np.array([]) self.t_shifts = np.array([]) @@ -903,11 +875,12 @@ def process(self): if self.upload: fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior() - save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', 'video_wheel_alignment.png')) + save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', f'video_wheel_alignment_{self.label}.png')) save_fig_path.parent.mkdir(exist_ok=True, parents=True) fig.savefig(save_fig_path) snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one) snp.outputs = [save_fig_path] snp.register_images(widths=['orig']) + plt.close(fig) return self.new_times diff --git a/ibllib/oneibl/data_handlers.py b/ibllib/oneibl/data_handlers.py index 19c737e15..b41fac1f4 100644 --- a/ibllib/oneibl/data_handlers.py +++ b/ibllib/oneibl/data_handlers.py @@ -131,7 +131,8 @@ def __init__(self, session_path, signatures, one=None): # For cortex lab we need to get the endpoint from the ibl alyx if self.lab == 'cortexlab': - self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=ONE(base_url='https://alyx.internationalbrainlab.org').alyx) + alyx = AlyxClient(base_url='https://alyx.internationalbrainlab.org', cache_rest=None) + self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=alyx) else: self.globus.add_endpoint(f'flatiron_{self.lab}', alyx=self.one.alyx) @@ -140,21 +141,19 @@ def __init__(self, session_path, signatures, one=None): def setUp(self): """Function to download necessary data to run tasks using globus-sdk.""" if self.lab == 'cortexlab': - one = ONE(base_url='https://alyx.internationalbrainlab.org') - df = super().getData(one=one) + df = super().getData(one=ONE(base_url='https://alyx.internationalbrainlab.org')) else: - one = self.one - df = super().getData() + df = super().getData(one=self.one) if len(df) == 0: - # If no datasets found in the cache only work off local file system do not attempt to download any missing data - # using globus + # If no datasets found in the cache only work off local file system do not attempt to + # download any missing data using Globus return # Check for space on local server. If less that 500 GB don't download new data space_free = shutil.disk_usage(self.globus.endpoints['local']['root_path'])[2] if space_free < 500e9: - _logger.warning('Space left on server is < 500GB, wont redownload new data') + _logger.warning('Space left on server is < 500GB, won\'t re-download new data') return rel_sess_path = '/'.join(df.iloc[0]['session_path'].split('/')[-3:]) @@ -190,7 +189,7 @@ def uploadData(self, outputs, version, **kwargs): return register_dataset(outputs, one=self.one, versions=versions, repository=data_repo, **kwargs) def cleanUp(self): - """Clean up, remove the files that were downloaded from globus once task has completed.""" + """Clean up, remove the files that were downloaded from Globus once task has completed.""" for file in self.local_paths: os.unlink(file) diff --git a/ibllib/oneibl/registration.py b/ibllib/oneibl/registration.py index 0996f01e0..554735e15 100644 --- a/ibllib/oneibl/registration.py +++ b/ibllib/oneibl/registration.py @@ -172,31 +172,14 @@ def register_session(self, ses_path, file_list=True, projects=None, procedures=N # Read in the experiment description file if it exists and get projects and procedures from here experiment_description_file = session_params.read_params(ses_path) + _, subject, date, number, *_ = folder_parts(ses_path) if experiment_description_file is None: collections = ['raw_behavior_data'] else: - projects = experiment_description_file.get('projects', projects) - procedures = experiment_description_file.get('procedures', procedures) - collections = ensure_list(session_params.get_task_collection(experiment_description_file)) - - # read meta data from the rig for the session from the task settings file - task_data = (raw.load_bpod(ses_path, collection) for collection in sorted(collections)) - # Filter collections where settings file was not found - if not (task_data := list(zip(*filter(lambda x: x[0] is not None, task_data)))): - raise ValueError(f'_iblrig_taskSettings.raw.json not found in {ses_path} Abort.') - settings, task_data = task_data - if len(settings) != len(collections): - raise ValueError(f'_iblrig_taskSettings.raw.json not found in {ses_path} Abort.') - - # Do some validation - _, subject, date, number, *_ = folder_parts(ses_path) - assert len({x['SUBJECT_NAME'] for x in settings}) == 1 and settings[0]['SUBJECT_NAME'] == subject - assert len({x['SESSION_DATE'] for x in settings}) == 1 and settings[0]['SESSION_DATE'] == date - assert len({x['SESSION_NUMBER'] for x in settings}) == 1 and settings[0]['SESSION_NUMBER'] == number - assert len({x['IS_MOCK'] for x in settings}) == 1 - assert len({md['PYBPOD_BOARD'] for md in settings}) == 1 - assert len({md.get('IBLRIG_VERSION') for md in settings}) == 1 - # assert len({md['IBLRIG_VERSION_TAG'] for md in settings}) == 1 + # Combine input projects/procedures with those in experiment description + projects = list({*experiment_description_file.get('projects', []), *(projects or [])}) + procedures = list({*experiment_description_file.get('procedures', []), *(procedures or [])}) + collections = session_params.get_task_collection(experiment_description_file) # query Alyx endpoints for subject, error if not found subject = self.assert_exists(subject, 'subjects') @@ -206,31 +189,62 @@ def register_session(self, ses_path, file_list=True, projects=None, procedures=N date_range=date, number=number, details=True, query_type='remote') - users = [] - for user in filter(None, map(lambda x: x.get('PYBPOD_CREATOR'), settings)): - user = self.assert_exists(user[0], 'users') # user is list of [username, uuid] - users.append(user['username']) - - # extract information about session duration and performance - start_time, end_time = _get_session_times(str(ses_path), settings, task_data) - n_trials, n_correct_trials = _get_session_performance(settings, task_data) - - # TODO Add task_protocols to Alyx sessions endpoint - task_protocols = [md['PYBPOD_PROTOCOL'] + md['IBLRIG_VERSION_TAG'] for md in settings] - # unless specified label the session projects with subject projects - projects = subject['projects'] if projects is None else projects - # makes sure projects is a list - projects = [projects] if isinstance(projects, str) else projects - - # unless specified label the session procedures with task protocol lookup - procedures = procedures or list(set(filter(None, map(self._alyx_procedure_from_task, task_protocols)))) - procedures = [procedures] if isinstance(procedures, str) else procedures - json_fields_names = ['IS_MOCK', 'IBLRIG_VERSION'] - json_field = {k: settings[0].get(k) for k in json_fields_names} - # The poo count field is only updated if the field is defined in at least one of the settings - poo_counts = [md.get('POOP_COUNT') for md in settings if md.get('POOP_COUNT') is not None] - if poo_counts: - json_field['POOP_COUNT'] = int(sum(poo_counts)) + if collections is None: # No task data + assert len(session) != 0, 'no session on Alyx and no tasks in experiment description' + # Fetch the full session JSON and assert that some basic information is present. + # Basically refuse to extract the data if key information is missing + session_details = self.one.alyx.rest('sessions', 'read', id=session_id[0], no_cache=True) + required = ('location', 'start_time', 'lab', 'users') + missing = [k for k in required if not session_details[k]] + assert not any(missing), 'missing session information: ' + ', '.join(missing) + task_protocols = task_data = settings = [] + json_field = None + users = session_details['users'] + else: # Get session info from task data + collections = ensure_list(collections) + # read meta data from the rig for the session from the task settings file + task_data = (raw.load_bpod(ses_path, collection) for collection in sorted(collections)) + # Filter collections where settings file was not found + if not (task_data := list(zip(*filter(lambda x: x[0] is not None, task_data)))): + raise ValueError(f'_iblrig_taskSettings.raw.json not found in {ses_path} Abort.') + settings, task_data = task_data + if len(settings) != len(collections): + raise ValueError(f'_iblrig_taskSettings.raw.json not found in {ses_path} Abort.') + + # Do some validation + assert len({x['SUBJECT_NAME'] for x in settings}) == 1 and settings[0]['SUBJECT_NAME'] == subject['nickname'] + assert len({x['SESSION_DATE'] for x in settings}) == 1 and settings[0]['SESSION_DATE'] == date + assert len({x['SESSION_NUMBER'] for x in settings}) == 1 and settings[0]['SESSION_NUMBER'] == number + assert len({x['IS_MOCK'] for x in settings}) == 1 + assert len({md['PYBPOD_BOARD'] for md in settings}) == 1 + assert len({md.get('IBLRIG_VERSION') for md in settings}) == 1 + # assert len({md['IBLRIG_VERSION_TAG'] for md in settings}) == 1 + + users = [] + for user in filter(None, map(lambda x: x.get('PYBPOD_CREATOR'), settings)): + user = self.assert_exists(user[0], 'users') # user is list of [username, uuid] + users.append(user['username']) + + # extract information about session duration and performance + start_time, end_time = _get_session_times(str(ses_path), settings, task_data) + n_trials, n_correct_trials = _get_session_performance(settings, task_data) + + # TODO Add task_protocols to Alyx sessions endpoint + task_protocols = [md['PYBPOD_PROTOCOL'] + md['IBLRIG_VERSION_TAG'] for md in settings] + # unless specified label the session projects with subject projects + projects = subject['projects'] if projects is None else projects + # makes sure projects is a list + projects = [projects] if isinstance(projects, str) else projects + + # unless specified label the session procedures with task protocol lookup + procedures = procedures or list(set(filter(None, map(self._alyx_procedure_from_task, task_protocols)))) + procedures = [procedures] if isinstance(procedures, str) else procedures + json_fields_names = ['IS_MOCK', 'IBLRIG_VERSION'] + json_field = {k: settings[0].get(k) for k in json_fields_names} + # The poo count field is only updated if the field is defined in at least one of the settings + poo_counts = [md.get('POOP_COUNT') for md in settings if md.get('POOP_COUNT') is not None] + if poo_counts: + json_field['POOP_COUNT'] = int(sum(poo_counts)) if not session: # Create session and weighings ses_ = {'subject': subject['nickname'], @@ -258,9 +272,13 @@ def register_session(self, ses_path, file_list=True, projects=None, procedures=N user = self.one.alyx.user self.register_weight(subject['nickname'], md['SUBJECT_WEIGHT'], date_time=md['SESSION_DATETIME'], user=user) - else: # if session exists update the JSON field - session = self.one.alyx.rest('sessions', 'read', id=session_id[0], no_cache=True) - self.one.alyx.json_field_update('sessions', session['id'], data=json_field) + else: # if session exists update a few key fields + data = {'procedures': procedures, 'projects': projects} + if task_protocols: + data['task_protocol'] = '/'.join(task_protocols) + session = self.one.alyx.rest('sessions', 'partial_update', id=session_id[0], data=data) + if json_field: + session['json'] = self.one.alyx.json_field_update('sessions', session['id'], data=json_field) _logger.info(session['url'] + ' ') # create associated water administration if not found @@ -279,7 +297,8 @@ def register_session(self, ses_path, file_list=True, projects=None, procedures=N return session, None # register all files that match the Alyx patterns and file_list - rename_files_compatibility(ses_path, settings[0]['IBLRIG_VERSION_TAG']) + if any(settings): + rename_files_compatibility(ses_path, settings[0]['IBLRIG_VERSION_TAG']) F = filter(lambda x: self._register_bool(x.name, file_list), self.find_files(ses_path)) recs = self.register_files(F, created_by=users[0] if users else None, versions=ibllib.__version__) return session, recs diff --git a/ibllib/pipes/__init__.py b/ibllib/pipes/__init__.py index 2b68cdb04..95e8c6ce9 100644 --- a/ibllib/pipes/__init__.py +++ b/ibllib/pipes/__init__.py @@ -1,8 +1,28 @@ -#!/usr/bin/env python -# -*- coding:utf-8 -*- -# @Author: Niccolò Bonacchi -# @Date: Friday, July 5th 2019, 11:46:37 am -from ibllib.io.flags import FLAG_FILE_NAMES +"""IBL preprocessing pipeline. + +This module concerns the data extraction and preprocessing for IBL data. The lab servers routinely +call `local_server.job_creator` to search for new sessions to extract. The job creator registers +the new session to Alyx (i.e. creates a new session record on the database), if required, then +deduces a set of tasks (a.k.a. the pipeline [*]_) from the 'experiment.description' file at the +root of the session (see `dynamic_pipeline.make_pipeline`). If no file exists one is created, +inferring the acquisition hardware from the task protocol. The new session's pipeline tasks are +then registered for another process (or server) to query. + +Another process calls `local_server.task_queue` to get a list of queued tasks from Alyx, then +`local_server.tasks_runner` to loop through tasks. Each task is run by called +`tasks.run_alyx_task` with a dictionary of task information, including the Task class and its +parameters. + +.. [*] A pipeline is a collection of tasks that depend on one another. A pipeline consists of + tasks associated with the same session path. Unlike pipelines, tasks are represented in Alyx. + A pipeline can be recreated given a list of task dictionaries. The order is defined by the + 'parents' field of each task. + +Notes +----- +All new tasks are subclasses of the base_tasks.DynamicTask class. All others are defunct and shall +be removed in the future. +""" def assign_task(task_deck, session_path, task, **kwargs): diff --git a/ibllib/pipes/dynamic_pipeline.py b/ibllib/pipes/dynamic_pipeline.py index 7c8fd6065..bc2caaf1b 100644 --- a/ibllib/pipes/dynamic_pipeline.py +++ b/ibllib/pipes/dynamic_pipeline.py @@ -1,3 +1,8 @@ +"""Task pipeline creation from an acquisition description. + +The principal function here is `make_pipeline` which reads an `_ibl_experiment.description.yaml` +file and determines the set of tasks required to preprocess the session. +""" import logging import re from collections import OrderedDict @@ -9,7 +14,6 @@ import ibllib.io.session_params as sess_params import ibllib.io.extractors.base -import ibllib.pipes.ephys_preprocessing as epp import ibllib.pipes.tasks as mtasks import ibllib.pipes.base_tasks as bstasks import ibllib.pipes.widefield_tasks as wtasks @@ -307,14 +311,12 @@ def make_pipeline(session_path, **pkwargs): if 'cameras' in devices: cams = list(devices['cameras'].keys()) subset_cams = [c for c in cams if c in ('left', 'right', 'body', 'belly')] - video_kwargs = {'device_collection': 'raw_video_data', - 'cameras': cams} + video_kwargs = {'device_collection': 'raw_video_data', 'cameras': cams} video_compressed = sess_params.get_video_compressed(acquisition_description) if video_compressed: # This is for widefield case where the video is already compressed - tasks[tn] = type((tn := 'VideoConvert'), (vtasks.VideoConvert,), {})( - **kwargs, **video_kwargs) + tasks[tn] = type((tn := 'VideoConvert'), (vtasks.VideoConvert,), {})(**kwargs, **video_kwargs) dlc_parent_task = tasks['VideoConvert'] tasks[tn] = type((tn := f'VideoSyncQC_{sync}'), (vtasks.VideoSyncQcCamlog,), {})( **kwargs, **video_kwargs, **sync_kwargs) @@ -335,11 +337,25 @@ def make_pipeline(session_path, **pkwargs): if sync_kwargs['sync'] != 'bpod': # Here we restrict to videos that we support (left, right or body) + # Currently there is no plan to run DLC on the belly cam + subset_cams = [c for c in cams if c in ('left', 'right', 'body')] video_kwargs['cameras'] = subset_cams tasks[tn] = type((tn := 'DLC'), (vtasks.DLC,), {})( **kwargs, **video_kwargs, parents=[dlc_parent_task]) - tasks['PostDLC'] = type('PostDLC', (epp.EphysPostDLC,), {})( - **kwargs, parents=[tasks['DLC'], tasks[f'VideoSyncQC_{sync}']]) + + # The PostDLC plots require a trials object for QC + # Find the first task that outputs a trials.table dataset + trials_task = ( + t for t in tasks.values() if any('trials.table' in f for f in t.signature.get('output_files', [])) + ) + if trials_task := next(trials_task, None): + parents = [tasks['DLC'], tasks[f'VideoSyncQC_{sync}'], trials_task] + trials_collection = getattr(trials_task, 'output_collection', 'alf') + else: + parents = [tasks['DLC'], tasks[f'VideoSyncQC_{sync}']] + trials_collection = 'alf' + tasks[tn] = type((tn := 'PostDLC'), (vtasks.EphysPostDLC,), {})( + **kwargs, cameras=subset_cams, trials_collection=trials_collection, parents=parents) # Audio tasks if 'microphone' in devices: diff --git a/ibllib/pipes/ephys_preprocessing.py b/ibllib/pipes/ephys_preprocessing.py index 9cef81a34..26cef7050 100644 --- a/ibllib/pipes/ephys_preprocessing.py +++ b/ibllib/pipes/ephys_preprocessing.py @@ -1,3 +1,8 @@ +"""(Deprecated) Electrophysiology data preprocessing tasks. + +These tasks are part of the old pipeline. This module has been replaced by the `ephys_tasks` module +and the dynamic pipeline. +""" import logging import re import shutil diff --git a/ibllib/pipes/local_server.py b/ibllib/pipes/local_server.py index 47f6322b5..e04037b22 100644 --- a/ibllib/pipes/local_server.py +++ b/ibllib/pipes/local_server.py @@ -1,3 +1,9 @@ +"""Lab server pipeline construction and task runner. + +This is the module called by the job services on the lab servers. See +iblscripts/deploy/serverpc/crons for the service scripts that employ this module. +""" +import logging import time from datetime import datetime from pathlib import Path @@ -11,7 +17,6 @@ from one.api import ONE from one.webclient import AlyxClient from one.remote.globus import get_lab_from_endpoint_id, get_local_endpoint_id -from iblutil.util import setup_logger from ibllib.io.extractors.base import get_pipeline, get_task_protocol, get_session_extractor_type from ibllib.pipes import tasks, training_preprocessing, ephys_preprocessing @@ -21,8 +26,10 @@ from ibllib.io.session_params import read_params from ibllib.pipes.dynamic_pipeline import make_pipeline, acquisition_description_legacy_session -_logger = setup_logger(__name__, level='INFO') -LARGE_TASKS = ['EphysVideoCompress', 'TrainingVideoCompress', 'SpikeSorting', 'EphysDLC'] +_logger = logging.getLogger(__name__) +LARGE_TASKS = [ + 'EphysVideoCompress', 'TrainingVideoCompress', 'SpikeSorting', 'EphysDLC', 'MesoscopePreprocess' +] def _get_pipeline_class(session_path, one): @@ -65,7 +72,7 @@ def _get_volume_usage(vol, label=''): def report_health(one): """ - Get a few indicators and label the json field of the corresponding lab with them + Get a few indicators and label the json field of the corresponding lab with them. """ status = {'python_version': sys.version, 'ibllib_version': pkg_resources.get_distribution("ibllib").version, @@ -163,11 +170,20 @@ def job_creator(root_path, one=None, dry=False, rerun=False, max_md5_size=None): def task_queue(mode='all', lab=None, alyx=None): """ Query waiting jobs from the specified Lab - :param mode: Whether to return all waiting tasks, or only small or large (specified in LARGE_TASKS) jobs - :param lab: lab name as per Alyx, otherwise try to infer from local globus install - :param one: ONE instance - ------- + Parameters + ---------- + mode : {'all', 'small', 'large'} + Whether to return all waiting tasks, or only small or large (specified in LARGE_TASKS) jobs. + lab : str + Lab name as per Alyx, otherwise try to infer from local Globus install. + alyx : one.webclient.AlyxClient + An Alyx instance. + + Returns + ------- + list of dict + A list of Alyx tasks associated with `lab` that have a 'Waiting' status. """ alyx = alyx or AlyxClient(cache_rest=None) if lab is None: @@ -207,14 +223,29 @@ def task_queue(mode='all', lab=None, alyx=None): def tasks_runner(subjects_path, tasks_dict, one=None, dry=False, count=5, time_out=None, **kwargs): """ Function to run a list of tasks (task dictionary from Alyx query) on a local server - :param subjects_path: - :param tasks_dict: - :param one: - :param dry: - :param count: maximum number of tasks to run - :param time_out: between each task, if time elapsed is greater than time out, returns (seconds) - :param kwargs: - :return: list of dataset dictionaries + + Parameters + ---------- + subjects_path : str, pathlib.Path + The location of the subject session folders, e.g. '/mnt/s0/Data/Subjects'. + tasks_dict : list of dict + A list of tasks to run. Typically the output of `task_queue`. + one : one.api.OneAlyx + An instance of ONE. + dry : bool, default=False + If true, simply prints the full session paths and task names without running the tasks. + count : int, default=5 + The maximum number of tasks to run from the tasks_dict list. + time_out : float, optional + The time in seconds to run tasks before exiting. If set this will run tasks until the + timeout has elapsed. NB: Only checks between tasks and will not interrupt a running task. + **kwargs + See ibllib.pipes.tasks.run_alyx_task. + + Returns + ------- + list of pathlib.Path + A list of datasets registered to Alyx. """ if one is None: one = ONE(cache_rest=None) diff --git a/ibllib/pipes/mesoscope_tasks.py b/ibllib/pipes/mesoscope_tasks.py index e922eaefb..fee1a9c4a 100644 --- a/ibllib/pipes/mesoscope_tasks.py +++ b/ibllib/pipes/mesoscope_tasks.py @@ -15,6 +15,7 @@ import logging import subprocess import shutil +import uuid from pathlib import Path from itertools import chain from collections import defaultdict, Counter @@ -461,25 +462,35 @@ def _create_db(self, meta): Inputs to suite2p run that deviate from default parameters. """ - # Currently only supporting single plane, assert that this is the case - # FIXME This checks for zstacks but not dual plane mode - if not isinstance(meta['scanImageParams']['hStackManager']['zs'], int): - raise NotImplementedError('Multi-plane imaging not yet supported, data seems to be multi-plane') - # Computing dx and dy - cXY = np.array([fov['topLeftDeg'] for fov in meta['FOV']]) + cXY = np.array([fov['Deg']['topLeft'] for fov in meta['FOV']]) cXY -= np.min(cXY, axis=0) nXnYnZ = np.array([fov['nXnYnZ'] for fov in meta['FOV']]) - sW = np.sqrt(np.sum((np.array([fov['topRightDeg'] for fov in meta['FOV']]) - np.array( - [fov['topLeftDeg'] for fov in meta['FOV']])) ** 2, axis=1)) - sH = np.sqrt(np.sum((np.array([fov['bottomLeftDeg'] for fov in meta['FOV']]) - np.array( - [fov['topLeftDeg'] for fov in meta['FOV']])) ** 2, axis=1)) + + # Currently supporting z-stacks but not supporting dual plane / volumetric imaging, assert that this is not the case + if np.any(nXnYnZ[:, 2] > 1): + raise NotImplementedError('Dual-plane imaging not yet supported, data seems to more than one plane per FOV') + + sW = np.sqrt(np.sum((np.array([fov['Deg']['topRight'] for fov in meta['FOV']]) - np.array( + [fov['Deg']['topLeft'] for fov in meta['FOV']])) ** 2, axis=1)) + sH = np.sqrt(np.sum((np.array([fov['Deg']['bottomLeft'] for fov in meta['FOV']]) - np.array( + [fov['Deg']['topLeft'] for fov in meta['FOV']])) ** 2, axis=1)) pixSizeX = nXnYnZ[:, 0] / sW pixSizeY = nXnYnZ[:, 1] / sH dx = np.round(cXY[:, 0] * pixSizeX).astype(dtype=np.int32) dy = np.round(cXY[:, 1] * pixSizeY).astype(dtype=np.int32) nchannels = len(meta['channelSaved']) if isinstance(meta['channelSaved'], list) else 1 + # Computing number of unique z-planes (slices in tiff) + # FIXME this should work if all FOVs are discrete or if all FOVs are continuous, but may not work for combination of both + slice_ids = [fov['slice_id'] for fov in meta['FOV']] + nplanes = len(set(slice_ids)) + + # Figuring out how many SI Rois we have (one unique ROI may have several FOVs) + # FIXME currently unused + # roiUUIDs = np.array([fov['roiUUID'] for fov in meta['FOV']]) + # nrois = len(np.unique(roiUUIDs)) + db = { 'data_path': sorted(map(str, self.session_path.glob(f'{self.device_collection}'))), 'save_path0': str(self.session_path.joinpath('alf')), @@ -498,13 +509,13 @@ def _create_db(self, meta): 'block_size': [128, 128], 'save_mat': True, # save the data to Fall.mat 'move_bin': True, # move the binary file to save_path - 'scalefactor': 1, # scale manually in x to account for overlap between adjacent ribbons UCL mesoscope 'mesoscan': True, - 'nplanes': 1, + 'nplanes': nplanes, 'nrois': len(meta['FOV']), 'nchannels': nchannels, 'fs': meta['scanImageParams']['hRoiManager']['scanVolumeRate'], 'lines': [list(np.asarray(fov['lineIdx']) - 1) for fov in meta['FOV']], # subtracting 1 to make 0-based + 'slices': slice_ids, # this tells us which FOV corresponds to which tiff slices 'tau': self.get_default_tau(), # deduce the GCamp used from Alyx mouse line (defaults to 1.5; that of GCaMP6s) 'functional_chan': 1, # for now, eventually find(ismember(meta.channelSaved == meta.channelID.green)) 'align_by_chan': 1, # for now, eventually find(ismember(meta.channelSaved == meta.channelID.red)) @@ -691,13 +702,14 @@ def _run(self, *args, provenance=Provenance.ESTIMATE, **kwargs): Notes ----- - Once the FOVs have been registered they cannot be updated with with task. Rerunning this - task will result in an error. + - Once the FOVs have been registered they cannot be updated with this task. Rerunning this + task will result in an error. + - This task modifies the first meta JSON file. All meta files are registered by this task. """ # Load necessary data (filename, collection, _), *_ = self.signature['input_files'] - meta_file = next(self.session_path.glob(f'{collection}/{filename}'), None) - meta = alfio.load_file_content(meta_file) or {} + meta_files = sorted(self.session_path.glob(f'{collection}/{filename}')) + meta = mesoscope.patch_imaging_meta(alfio.load_file_content(meta_files[0]) or {}) nFOV = len(meta.get('FOV', [])) suffix = None if provenance is Provenance.HISTOLOGY else provenance.name.lower() @@ -707,7 +719,7 @@ def _run(self, *args, provenance=Provenance.ESTIMATE, **kwargs): mean_image_mlapdv, mean_image_ids = self.project_mlapdv(meta) # Save the meta data file with new coordinate fields - with open(meta_file, 'w') as fp: + with open(meta_files[0], 'w') as fp: json.dump(meta, fp) # Save the mean image datasets @@ -736,7 +748,47 @@ def _run(self, *args, provenance=Provenance.ESTIMATE, **kwargs): # Register FOVs in Alyx self.register_fov(meta, suffix) - return sorted([meta_file, *roi_files, *mean_image_files]) + return sorted([*meta_files, *roi_files, *mean_image_files]) + + def update_surgery_json(self, meta, normal_vector): + """ + Update surgery JSON with surface normal vector. + + Adds the key 'surface_normal_unit_vector' to the most recent surgery JSON, containing the + provided three element vector. The recorded craniotomy center must match the coordinates + in the provided meta file. + + Parameters + ---------- + meta : dict + The imaging meta data file containing the 'centerMM' key. + normal_vector : array_like + A three element unit vector normal to the surface of the craniotomy center. + + Returns + ------- + dict + The updated surgery record, or None if no surgeries found. + """ + if not self.one or self.one.offline: + _logger.warning('failed to update surgery JSON: ONE offline') + return + # Update subject JSON with unit normal vector of craniotomy centre (used in histology) + subject = self.one.path2ref(self.session_path, parse=False)['subject'] + surgeries = self.one.alyx.rest('surgeries', 'list', subject=subject, procedure='craniotomy') + if not surgeries: + _logger.error(f'Surgery not found for subject "{subject}"') + return + surgery = surgeries[0] # Check most recent surgery in list + center = (meta['centerMM']['ML'], meta['centerMM']['AP']) + match = (k for k, v in surgery['json'].items() if + str(k).startswith('craniotomy') and np.allclose(v['center'], center)) + if (key := next(match, None)) is None: + _logger.error('Failed to update surgery JSON: no matching craniotomy found') + return surgery + data = {key: {**surgery['json'][key], 'surface_normal_unit_vector': tuple(normal_vector)}} + surgery['json'] = self.one.alyx.json_field_update('subjects', subject, data=data) + return surgery def roi_mlapdv(self, nFOV: int, suffix=None): """ @@ -755,9 +807,9 @@ def roi_mlapdv(self, nFOV: int, suffix=None): Returns ------- - dict of int: numpy.array + dict of int : numpy.array A map of field of view to ROI MLAPDV coordinates. - dict of int: numpy.array + dict of int : numpy.array A map of field of view to ROI brain location IDs. """ all_mlapdv = {} @@ -842,8 +894,11 @@ def register_fov(self, meta: dict, suffix: str = None) -> (list, list): slice_counts = Counter(f['roiUUID'] for f in meta.get('FOV', [])) # Create a new stack in Alyx for all stacks containing more than one slice. # Map of ScanImage ROI UUID to Alyx ImageStack UUID. - stack_ids = {i: self.one.alyx.rest('imaging-stack', 'create', data={'name': i})['id'] - for i in slice_counts if slice_counts[i] > 1} + if dry: + stack_ids = {i: uuid.uuid4() for i in slice_counts if slice_counts[i] > 1} + else: + stack_ids = {i: self.one.alyx.rest('imaging-stack', 'create', data={'name': i})['id'] + for i in slice_counts if slice_counts[i] > 1} for i, fov in enumerate(meta.get('FOV', [])): assert set(fov.keys()) >= {'MLAPDV', 'nXnYnZ', 'roiUUID'} @@ -962,6 +1017,9 @@ def project_mlapdv(self, meta, atlas=None): # Get the surface normal unit vector of dorsal triangle normal_vector = surface_normal(dorsal_triangle) + # Update the surgery JSON field with normal unit vector, for use in histology alignment + self.update_surgery_json(meta, normal_vector) + # find the coordDV that sits on the triangular face and had [coordML, coordAP] coordinates; # the three vertices defining the triangle face_vertices = points[dorsal_connectivity_list[face_ind, :], :] @@ -1005,13 +1063,6 @@ def project_mlapdv(self, meta, atlas=None): # xx and yy are in mm in coverslip space points = ((0, fov['nXnYnZ'][0] - 1), (0, fov['nXnYnZ'][1] - 1)) - if 'MM' not in fov: - fov['MM'] = { - 'topLeft': fov.pop('topLeftMM'), - 'topRight': fov.pop('topRightMM'), - 'bottomLeft': fov.pop('bottomLeftMM'), - 'bottomRight': fov.pop('bottomRightMM') - } # The four corners of the FOV, determined by taking the center of the craniotomy in MM, # the x-y coordinates of the imaging window center (from the tiled reference image) in # galvanometer units, and the x-y coordinates of the FOV center in galvanometer units. diff --git a/ibllib/pipes/misc.py b/ibllib/pipes/misc.py index d3911533f..39871ad00 100644 --- a/ibllib/pipes/misc.py +++ b/ibllib/pipes/misc.py @@ -1,3 +1,4 @@ +"""Miscellaneous pipeline utility functions.""" import ctypes import hashlib import json diff --git a/ibllib/pipes/purge_rig_data.py b/ibllib/pipes/purge_rig_data.py index abe0251da..9b7afba05 100644 --- a/ibllib/pipes/purge_rig_data.py +++ b/ibllib/pipes/purge_rig_data.py @@ -1,13 +1,12 @@ -#!/usr/bin/env python -# -*- coding:utf-8 -*- -# @Author: Niccolò Bonacchi -# @Date: Thursday, March 28th 2019, 7:53:44 pm """ -Purge data from RIG +Purge data from acquisition PC. + +Steps: + - Find all files by rglob - Find all sessions of the found files - Check Alyx if corresponding datasetTypes have been registered as existing -sessions and files on Flatiron + sessions and files on Flatiron - Delete local raw file if found on Flatiron """ import argparse diff --git a/ibllib/pipes/tasks.py b/ibllib/pipes/tasks.py index b6e632579..25a645385 100644 --- a/ibllib/pipes/tasks.py +++ b/ibllib/pipes/tasks.py @@ -1,3 +1,4 @@ +"""The abstract Pipeline and Task superclasses and concrete task runner.""" from pathlib import Path import abc import logging @@ -602,22 +603,39 @@ def name(self): def run_alyx_task(tdict=None, session_path=None, one=None, job_deck=None, max_md5_size=None, machine=None, clobber=True, location='server', mode='log'): """ - Runs a single Alyx job and registers output datasets - :param tdict: - :param session_path: - :param one: - :param job_deck: optional list of job dictionaries belonging to the session. Needed - to check dependency status if the jdict has a parent field. If jdict has a parent and - job_deck is not entered, will query the database - :param max_md5_size: in bytes, if specified, will not compute the md5 checksum above a given - filesize to save time - :param machine: string identifying the machine the task is run on, optional - :param clobber: bool, if True any existing logs are overwritten, default is True - :param location: where you are running the task, 'server' - local lab server, 'remote' - any - compute node/ computer, 'SDSC' - flatiron compute node, 'AWS' - using data from aws s3 - :param mode: str ('log' or 'raise') behaviour to adopt if an error occured. If 'raise', it - will Raise the error at the very end of this function (ie. after having labeled the tasks) - :return: + Runs a single Alyx job and registers output datasets. + + Parameters + ---------- + tdict : dict + An Alyx task dictionary to instantiate and run. + session_path : str, pathlib.Path + A session path containing the task input data. + one : one.api.OneAlyx + An instance of ONE. + job_deck : list of dict, optional + A list of all tasks in the same pipeline. If None, queries Alyx to get this. + max_md5_size : int, optional + An optional maximum file size in bytes. Files with sizes larger than this will not have + their MD5 checksum calculated to save time. + machine : str, optional + A string identifying the machine the task is run on. + clobber : bool, default=True + If true any existing logs are overwritten on Alyx. + location : {'remote', 'server', 'sdsc', 'aws'} + Where you are running the task, 'server' - local lab server, 'remote' - any + compute node/ computer, 'sdsc' - Flatiron compute node, 'aws' - using data from AWS S3 + node. + mode : {'log', 'raise}, default='log' + Behaviour to adopt if an error occurred. If 'raise', it will raise the error at the very + end of this function (i.e. after having labeled the tasks). + + Returns + ------- + Task + The instantiated task object that was run. + list of pathlib.Path + A list of registered datasets. """ registered_dsets = [] # here we need to check parents' status, get the job_deck if not available diff --git a/ibllib/pipes/training_preprocessing.py b/ibllib/pipes/training_preprocessing.py index b47adcc65..db41f8992 100644 --- a/ibllib/pipes/training_preprocessing.py +++ b/ibllib/pipes/training_preprocessing.py @@ -1,3 +1,9 @@ +"""(Deprecated) Training data preprocessing tasks. + +These tasks are part of the old pipeline. This module has been replaced by the dynamic pipeline +and the `behavior_tasks` module. +""" + import logging from collections import OrderedDict from one.alf.files import session_path_parts diff --git a/ibllib/pipes/video_tasks.py b/ibllib/pipes/video_tasks.py index eaf00aaa0..7f501a065 100644 --- a/ibllib/pipes/video_tasks.py +++ b/ibllib/pipes/video_tasks.py @@ -1,9 +1,13 @@ import logging import subprocess -import cv2 import traceback from pathlib import Path +import cv2 +import pandas as pd +import numpy as np + +from ibllib.qc.dlc import DlcQC from ibllib.io import ffmpeg, raw_daq_loaders from ibllib.pipes import base_tasks from ibllib.io.video import get_video_meta @@ -11,6 +15,9 @@ from ibllib.qc.camera import run_all_qc as run_camera_qc from ibllib.misc import check_nvidia_driver from ibllib.io.video import label_from_path, assert_valid_label +from ibllib.plots.snapshot import ReportSnapshot +from ibllib.plots.figures import dlc_qc_plot +from brainbox.behavior.dlc import likelihood_threshold, get_licks, get_pupil_diameter, get_smooth_pupil_diameter _logger = logging.getLogger('ibllib') @@ -48,9 +55,9 @@ def assert_expected_outputs(self, raise_error=True): required = any('Camera.frameData' in x or 'Camera.timestamps' in x for x in map(str, files)) if not (everything_is_fine and required): for out in self.outputs: - _logger.error(f"{out}") + _logger.error(f'{out}') if raise_error: - raise FileNotFoundError("Missing outputs after task completion") + raise FileNotFoundError('Missing outputs after task completion') return everything_is_fine, files @@ -120,7 +127,7 @@ def _run(self): # convert the avi files to mp4 avi_file = next(self.session_path.joinpath(self.device_collection).glob(f'{cam}_cam*.avi')) mp4_file = self.session_path.joinpath(self.device_collection, f'_iblrig_{cam}Camera.raw.mp4') - command2run = f"ffmpeg -i {str(avi_file)} -c:v copy -c:a copy -y {str(mp4_file)}" + command2run = f'ffmpeg -i {str(avi_file)} -c:v copy -c:a copy -y {str(mp4_file)}' process = subprocess.Popen(command2run, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) info, error = process.communicate() @@ -191,7 +198,7 @@ def _run(self, qc=True, **kwargs): class VideoSyncQcBpod(base_tasks.VideoTask): """ Task to sync camera timestamps to main DAQ timestamps - N.B Signatures only reflect new daq naming convention, non compatible with ephys when not running on server + N.B Signatures only reflect new daq naming convention, non-compatible with ephys when not running on server """ priority = 40 job_size = 'small' @@ -241,7 +248,7 @@ def _run(self, **kwargs): class VideoSyncQcNidq(base_tasks.VideoTask): """ Task to sync camera timestamps to main DAQ timestamps - N.B Signatures only reflect new daq naming convention, non compatible with ephys when not running on server + N.B Signatures only reflect new daq naming convention, non-compatible with ephys when not running on server """ priority = 40 job_size = 'small' @@ -328,19 +335,19 @@ def _check_dlcenv(self): f'Scripts run_dlc.sh and run_dlc.py do not exist in {self.scripts}' assert len(list(self.scripts.rglob('run_motion.*'))) == 2, \ f'Scripts run_motion.sh and run_motion.py do not exist in {self.scripts}' - assert self.dlcenv.exists(), f"DLC environment does not exist in assumed location {self.dlcenv}" + assert self.dlcenv.exists(), f'DLC environment does not exist in assumed location {self.dlcenv}' command2run = f"source {self.dlcenv}; python -c 'import iblvideo; print(iblvideo.__version__)'" process = subprocess.Popen( command2run, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - executable="/bin/bash" + executable='/bin/bash' ) info, error = process.communicate() if process.returncode != 0: raise AssertionError(f"DLC environment check failed\n{error.decode('utf-8')}") - version = info.decode("utf-8").strip().split('\n')[-1] + version = info.decode('utf-8').strip().split('\n')[-1] return version @staticmethod @@ -378,11 +385,11 @@ def _run(self, cams=None, overwrite=False): file_mp4 = next(self.session_path.joinpath('raw_video_data').glob(f'_iblrig_{cam}Camera.raw*.mp4')) if not file_mp4.exists(): # In this case we set the status to Incomplete. - _logger.error(f"No raw video file available for {cam}, skipping.") + _logger.error(f'No raw video file available for {cam}, skipping.') self.status = -3 continue if not self._video_intact(file_mp4): - _logger.error(f"Corrupt raw video file {file_mp4}") + _logger.error(f'Corrupt raw video file {file_mp4}') self.status = -1 continue # Check that dlc environment is ok, shell scripts exists, and get iblvideo version, GPU addressable @@ -398,13 +405,13 @@ def _run(self, cams=None, overwrite=False): shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - executable="/bin/bash", + executable='/bin/bash', ) info, error = process.communicate() # info_str = info.decode("utf-8").strip() # _logger.info(info_str) if process.returncode != 0: - error_str = error.decode("utf-8").strip() + error_str = error.decode('utf-8').strip() _logger.error(f'DLC failed for {cam}Camera.\n\n' f'++++++++ Output of subprocess for debugging ++++++++\n\n' f'{error_str}\n' @@ -423,13 +430,13 @@ def _run(self, cams=None, overwrite=False): shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - executable="/bin/bash", + executable='/bin/bash', ) info, error = process.communicate() - # info_str = info.decode("utf-8").strip() + # info_str = info.decode('utf-8').strip() # _logger.info(info_str) if process.returncode != 0: - error_str = error.decode("utf-8").strip() + error_str = error.decode('utf-8').strip() _logger.error(f'Motion energy failed for {cam}Camera.\n\n' f'++++++++ Output of subprocess for debugging ++++++++\n\n' f'{error_str}\n' @@ -440,7 +447,7 @@ def _run(self, cams=None, overwrite=False): f'{cam}Camera.ROIMotionEnergy*.npy'))) actual_outputs.append(next(self.session_path.joinpath('alf').glob( f'{cam}ROIMotionEnergy.position*.npy'))) - except BaseException: + except Exception: _logger.error(traceback.format_exc()) self.status = -1 continue @@ -450,3 +457,156 @@ def _run(self, cams=None, overwrite=False): actual_outputs = None self.status = -1 return actual_outputs + + +class EphysPostDLC(base_tasks.VideoTask): + """ + The post_dlc task takes dlc traces as input and computes useful quantities, as well as qc. + """ + io_charge = 90 + level = 3 + force = True + + def __int__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.trials_collection = kwargs.get('trials_collection', 'alf') + + @property + def signature(self): + return { + 'input_files': [(f'_ibl_{cam}Camera.dlc.pqt', 'alf', True) for cam in self.cameras] + + [(f'_ibl_{cam}Camera.times.npy', 'alf', True) for cam in self.cameras] + + # the following are required for the DLC plot only + # they are not strictly required, some plots just might be skipped + # In particular the raw videos don't need to be downloaded as they can be streamed + [(f'_iblrig_{cam}Camera.raw.mp4', self.device_collection, True) for cam in self.cameras] + + [(f'{cam}ROIMotionEnergy.position.npy', 'alf', False) for cam in self.cameras] + + # The trials table is used in the DLC QC, however this is not an essential dataset + [('_ibl_trials.table.pqt', self.trials_collection, False)], + 'output_files': [(f'_ibl_{cam}Camera.features.pqt', 'alf', True) for cam in self.cameras] + + [('licks.times.npy', 'alf', True)] + } + + def _run(self, overwrite=True, run_qc=True, plot_qc=True): + """ + Run the PostDLC task. Returns a list of file locations for the output files in signature. The created plot + (dlc_qc_plot.png) is not returned, but saved in session_path/snapshots and uploaded to Alyx as a note. + + :param overwrite: bool, whether to recompute existing output files (default is False). + Note that the dlc_qc_plot will be (re-)computed even if overwrite = False + :param run_qc: bool, whether to run the DLC QC (default is True) + :param plot_qc: book, whether to create the dlc_qc_plot (default is True) + + """ + # Check if output files exist locally + exist, output_files = self.assert_expected(self.signature['output_files'], silent=True) + if exist and not overwrite: + _logger.warning('EphysPostDLC outputs exist and overwrite=False, skipping computations of outputs.') + else: + if exist and overwrite: + _logger.warning('EphysPostDLC outputs exist and overwrite=True, overwriting existing outputs.') + # Find all available DLC files + dlc_files = list(Path(self.session_path).joinpath('alf').glob('_ibl_*Camera.dlc.*')) + for dlc_file in dlc_files: + _logger.debug(dlc_file) + output_files = [] + combined_licks = [] + + for dlc_file in dlc_files: + # Catch unforeseen exceptions and move on to next cam + try: + cam = label_from_path(dlc_file) + # load dlc trace and camera times + dlc = pd.read_parquet(dlc_file) + dlc_thresh = likelihood_threshold(dlc, 0.9) + # try to load respective camera times + try: + dlc_t = np.load(next(Path(self.session_path).joinpath('alf').glob(f'_ibl_{cam}Camera.times.*npy'))) + times = True + if dlc_t.shape[0] == 0: + _logger.error(f'camera.times empty for {cam} camera. ' + f'Computations using camera.times will be skipped') + self.status = -1 + times = False + elif dlc_t.shape[0] < len(dlc_thresh): + _logger.error(f'Camera times shorter than DLC traces for {cam} camera. ' + f'Computations using camera.times will be skipped') + self.status = -1 + times = 'short' + except StopIteration: + self.status = -1 + times = False + _logger.error(f'No camera.times for {cam} camera. ' + f'Computations using camera.times will be skipped') + # These features are only computed from left and right cam + if cam in ('left', 'right'): + features = pd.DataFrame() + # If camera times are available, get the lick time stamps for combined array + if times is True: + _logger.info(f'Computing lick times for {cam} camera.') + combined_licks.append(get_licks(dlc_thresh, dlc_t)) + elif times is False: + _logger.warning(f'Skipping lick times for {cam} camera as no camera.times available') + elif times == 'short': + _logger.warning(f'Skipping lick times for {cam} camera as camera.times are too short') + # Compute pupil diameter, raw and smoothed + _logger.info(f'Computing raw pupil diameter for {cam} camera.') + features['pupilDiameter_raw'] = get_pupil_diameter(dlc_thresh) + try: + _logger.info(f'Computing smooth pupil diameter for {cam} camera.') + features['pupilDiameter_smooth'] = get_smooth_pupil_diameter(features['pupilDiameter_raw'], + cam) + except Exception: + _logger.error(f'Computing smooth pupil diameter for {cam} camera failed, saving all NaNs.') + _logger.error(traceback.format_exc()) + features['pupilDiameter_smooth'] = np.nan + # Save to parquet + features_file = Path(self.session_path).joinpath('alf', f'_ibl_{cam}Camera.features.pqt') + features.to_parquet(features_file) + output_files.append(features_file) + + # For all cams, compute DLC QC if times available + if run_qc is True and times in [True, 'short']: + # Setting download_data to False because at this point the data should be there + qc = DlcQC(self.session_path, side=cam, one=self.one, download_data=False) + qc.run(update=True) + else: + if times is False: + _logger.warning(f'Skipping QC for {cam} camera as no camera.times available') + if not run_qc: + _logger.warning(f'Skipping QC for {cam} camera as run_qc=False') + + except Exception: + _logger.error(traceback.format_exc()) + self.status = -1 + continue + + # Combined lick times + if len(combined_licks) > 0: + lick_times_file = Path(self.session_path).joinpath('alf', 'licks.times.npy') + np.save(lick_times_file, sorted(np.concatenate(combined_licks))) + output_files.append(lick_times_file) + else: + _logger.warning('No lick times computed for this session.') + + if plot_qc: + _logger.info('Creating DLC QC plot') + try: + session_id = self.one.path2eid(self.session_path) + fig_path = self.session_path.joinpath('snapshot', 'dlc_qc_plot.png') + if not fig_path.parent.exists(): + fig_path.parent.mkdir(parents=True, exist_ok=True) + fig = dlc_qc_plot(self.session_path, one=self.one, cameras=self.cameras, device_collection=self.device_collection, + trials_collection=self.trials_collection) + fig.savefig(fig_path) + fig.clf() + snp = ReportSnapshot(self.session_path, session_id, one=self.one) + snp.outputs = [fig_path] + snp.register_images(widths=['orig'], + function=str(dlc_qc_plot.__module__) + '.' + str(dlc_qc_plot.__name__)) + except Exception: + _logger.error('Could not create and/or upload DLC QC Plot') + _logger.error(traceback.format_exc()) + self.status = -1 + + return output_files diff --git a/ibllib/plots/figures.py b/ibllib/plots/figures.py index 17762ce01..34a444e9b 100644 --- a/ibllib/plots/figures.py +++ b/ibllib/plots/figures.py @@ -669,7 +669,8 @@ def raw_destripe(raw, fs, t0, i_plt, n_plt, return fig, axs -def dlc_qc_plot(session_path, one=None): +def dlc_qc_plot(session_path, one=None, device_collection='raw_video_data', + cameras=('left', 'right', 'body'), trials_collection='alf'): """ Creates DLC QC plot. Data is searched first locally, then on Alyx. Panels that lack required data are skipped. @@ -707,14 +708,13 @@ def dlc_qc_plot(session_path, one=None): if one.alyx.base_url == 'https://alyx.cortexlab.net': one = ONE(base_url='https://alyx.internationalbrainlab.org') data = {} - cams = ['left', 'right', 'body'] session_path = Path(session_path) # Load data for each camera - for cam in cams: + for cam in cameras: # Load a single frame for each video # Check if video data is available locally,if yes, load a single frame - video_path = session_path.joinpath('raw_video_data', f'_iblrig_{cam}Camera.raw.mp4') + video_path = session_path.joinpath(device_collection, f'_iblrig_{cam}Camera.raw.mp4') if video_path.exists(): data[f'{cam}_frame'] = get_video_frame(video_path, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0] # If not, try to stream a frame (try three times) @@ -725,7 +725,7 @@ def dlc_qc_plot(session_path, one=None): try: data[f'{cam}_frame'] = get_video_frame(video_url, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0] break - except BaseException: + except Exception: if tries < 2: tries += 1 logger.info(f"Streaming {cam} video failed, retrying x{tries}") @@ -757,19 +757,20 @@ def dlc_qc_plot(session_path, one=None): data[f'{cam}_{feat}'] = None # If we have no frame and/or no DLC and/or no times for all cams, raise an error, something is really wrong - assert any([data[f'{cam}_frame'] is not None for cam in cams]), "No camera data could be loaded, aborting." - assert any([data[f'{cam}_dlc'] is not None for cam in cams]), "No DLC data could be loaded, aborting." - assert any([data[f'{cam}_times'] is not None for cam in cams]), "No camera times data could be loaded, aborting." + assert any(data[f'{cam}_frame'] is not None for cam in cameras), "No camera data could be loaded, aborting." + assert any(data[f'{cam}_dlc'] is not None for cam in cameras), "No DLC data could be loaded, aborting." + assert any(data[f'{cam}_times'] is not None for cam in cameras), "No camera times data could be loaded, aborting." # Load session level data for alf_object in ['trials', 'wheel', 'licks']: try: - data[f'{alf_object}'] = alfio.load_object(session_path.joinpath('alf'), alf_object) # load locally + data[f'{alf_object}'] = alfio.load_object(session_path.joinpath(trials_collection), alf_object) # load locally continue except ALFObjectNotFound: pass try: - data[f'{alf_object}'] = one.load_object(one.path2eid(session_path), alf_object) # then try from alyx + # then try from alyx + data[f'{alf_object}'] = one.load_object(one.path2eid(session_path), alf_object, collection=trials_collection) except ALFObjectNotFound: logger.warning(f"Could not load {alf_object} object, some plots have to be skipped.") data[f'{alf_object}'] = None @@ -786,7 +787,7 @@ def dlc_qc_plot(session_path, one=None): # Make a list of panels, if inputs are missing, instead input a text to display panels = [] # Panel A, B, C: Trace on frame - for cam in cams: + for cam in cameras: if data[f'{cam}_frame'] is not None and data[f'{cam}_dlc'] is not None: panels.append((plot_trace_on_frame, {'frame': data[f'{cam}_frame'], 'dlc_df': data[f'{cam}_dlc'], 'cam': cam})) @@ -795,15 +796,14 @@ def dlc_qc_plot(session_path, one=None): # If trials data is not there, we cannot plot any of the trial average plots, skip all remaining panels if data['trials'] is None: - panels.extend([(None, 'No trial data,\ncannot compute trial avgs') for i in range(7)]) + panels.extend([(None, 'No trial data,\ncannot compute trial avgs')] * 7) else: # Panel D: Motion energy - camera_dict = {'left': {'motion_energy': data['left_ROIMotionEnergy'], 'times': data['left_times']}, - 'right': {'motion_energy': data['right_ROIMotionEnergy'], 'times': data['right_times']}, - 'body': {'motion_energy': data['body_ROIMotionEnergy'], 'times': data['body_times']}} - for cam in ['left', 'right', 'body']: # Remove cameras where we don't have motion energy AND camera times - if camera_dict[cam]['motion_energy'] is None or camera_dict[cam]['times'] is None: - _ = camera_dict.pop(cam) + camera_dict = {} + for cam in cameras: # Remove cameras where we don't have motion energy AND camera times + d = {'motion_energy': data.get(f'{cam}_ROIMotionEnergy'), 'times': data.get(f'{cam}_times')} + if not any(x is None for x in d.values()): + camera_dict[cam] = d if len(camera_dict) > 0: panels.append((plot_motion_energy_hist, {'camera_dict': camera_dict, 'trials_df': data['trials']})) else: @@ -833,7 +833,7 @@ def dlc_qc_plot(session_path, one=None): 'trials_df': data['trials'], 'feature': 'nose_tip', 'legend': False, 'cam': cam})) else: - panels.extend([(None, 'Data missing or corrupt\nSpeed histograms') for i in range(2)]) + panels.extend([(None, 'Data missing or corrupt\nSpeed histograms')] * 2) # Panel H and I: Lick plots if data['licks'] and data['licks'].times.shape[0] > 0: @@ -846,7 +846,7 @@ def dlc_qc_plot(session_path, one=None): # Try if all data is there for left cam first, otherwise right for cam in ['left', 'right']: fail = False - if (data[f'{cam}_times'] is not None and data[f'{cam}_features'] is not None + if (data.get(f'{cam}_times') is not None and data.get(f'{cam}_features') is not None and len(data[f'{cam}_times']) >= len(data[f'{cam}_features']) and not np.all(np.isnan(data[f'{cam}_features'].pupilDiameter_smooth))): break @@ -872,7 +872,7 @@ def dlc_qc_plot(session_path, one=None): else: try: panel[0](**panel[1]) - except BaseException: + except Exception: logger.error(f'Error in {panel[0].__name__}\n' + traceback.format_exc()) ax.text(.5, .5, f'Error while plotting\n{panel[0].__name__}', color='r', fontweight='bold', fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes) diff --git a/ibllib/tests/test_mesoscope.py b/ibllib/tests/test_mesoscope.py index 4579d202b..b828386db 100644 --- a/ibllib/tests/test_mesoscope.py +++ b/ibllib/tests/test_mesoscope.py @@ -35,23 +35,22 @@ def test_meta(self): """ expected = { 'data_path': [str(self.img_path)], + 'save_path0': str(self.session_path.joinpath('alf')), 'fast_disk': '', + 'look_one_level_down': False, 'num_workers': -1, - 'save_path0': str(self.session_path.joinpath('alf')), - 'move_bin': True, + 'num_workers_roi': -1, 'keep_movie_raw': False, 'delete_bin': False, 'batch_size': 500, - 'combined': False, - 'look_one_level_down': False, - 'num_workers_roi': -1, 'nimg_init': 400, + 'combined': False, 'nonrigid': True, 'maxregshift': 0.05, 'denoise': 1, 'block_size': [128, 128], 'save_mat': True, - 'scalefactor': 1, + 'move_bin': True, 'mesoscan': True, 'nplanes': 1, 'tau': 1.5, @@ -61,6 +60,7 @@ def test_meta(self): 'nchannels': 1, 'fs': 6.8, 'lines': [[3, 4, 5]], + 'slices': [0], 'dx': np.array([0], dtype=int), 'dy': np.array([0], dtype=int), } @@ -69,7 +69,7 @@ def test_meta(self): 'scanImageParams': {'hStackManager': {'zs': 320}, 'hRoiManager': {'scanVolumeRate': 6.8}}, 'FOV': [{'topLeftDeg': [-1, 1.3], 'topRightDeg': [3, 1.3], 'bottomLeftDeg': [-1, 5.2], - 'nXnYnZ': [512, 512, 1], 'channelIdx': 2, 'lineIdx': [4, 5, 6]}] + 'nXnYnZ': [512, 512, 1], 'channelIdx': 2, 'lineIdx': [4, 5, 6], 'slice_id': 0}] } with open(self.img_path.joinpath('_ibl_rawImagingData.meta.json'), 'w') as f: json.dump(meta, f) @@ -150,6 +150,41 @@ def test_nearest_neighbour_1d(self): np.testing.assert_array_equal(val, [1., 1., 1., 3., 3., 2., 5., 5.]) np.testing.assert_array_equal(ind, [1, 1, 1, 4, 4, 0, 3, 3]) + def test_update_surgery_json(self): + """Test for MesoscopeFOV.update_surgery_json method. + + Here we mock the Alyx object and simply check the method's calls. + """ + one = ONE(**TEST_DB) + task = MesoscopeFOV('/foo/bar/subject/2020-01-01/001', one=one) + record = {'json': {'craniotomy_00': {'center': [1., -3.]}, 'craniotomy_01': {'center': [2.7, -1.3]}}} + normal_vector = np.array([0.5, 1., 0.]) + meta = {'centerMM': {'ML': 2.7, 'AP': -1.30000000001}} + with mock.patch.object(one.alyx, 'rest', return_value=[record, {}]), \ + mock.patch.object(one.alyx, 'json_field_update') as mock_rest: + task.update_surgery_json(meta, normal_vector) + expected = {'craniotomy_01': {'center': [2.7, -1.3], + 'surface_normal_unit_vector': (0.5, 1., 0.)}} + mock_rest.assert_called_once_with('subjects', 'subject', data=expected) + + # Check errors and warnings + # No matching craniotomy center + with self.assertLogs('ibllib.pipes.mesoscope_tasks', 'ERROR'), \ + mock.patch.object(one.alyx, 'rest', return_value=[record, {}]): + task.update_surgery_json({'centerMM': {'ML': 0., 'AP': 0.}}, normal_vector) + # No matching surgery records + with self.assertLogs('ibllib.pipes.mesoscope_tasks', 'ERROR'), \ + mock.patch.object(one.alyx, 'rest', return_value=[]): + task.update_surgery_json(meta, normal_vector) + # ONE offline + one.mode = 'local' + try: + with self.assertLogs('ibllib.pipes.mesoscope_tasks', 'WARNING'): + task.update_surgery_json(meta, normal_vector) + finally: + # ONE function is cached so we must reset the mode for other tests + one.mode = 'auto' + class TestRegisterFOV(unittest.TestCase): """Test for MesoscopeFOV.register_fov method.""" @@ -173,7 +208,7 @@ def test_register_fov(self): 'bottomLeft': [2317.3, -2181.4, -466.3], 'bottomRight': [2862.7, -2206.9, -679.4], 'center': [2596.1, -1900.5, -588.6]} meta = {'FOV': [{'MLAPDV': mlapdv, 'nXnYnZ': [512, 512, 1], 'roiUUID': 0}]} - with unittest.mock.patch.object(self.one.alyx, 'rest') as mock_rest: + with unittest.mock.patch.object(task.one.alyx, 'rest') as mock_rest: task.register_fov(meta, 'estimate') calls = mock_rest.call_args_list self.assertEqual(3, len(calls)) @@ -197,8 +232,8 @@ def test_register_fov(self): # Check dry mode with suffix input = None for file in self.session_path.joinpath('alf', 'FOV_00').glob('mpciMeanImage.*'): file.replace(file.with_name(file.name.replace('_estimate', ''))) - self.one.mode = 'local' - with unittest.mock.patch.object(self.one.alyx, 'rest') as mock_rest: + task.one.mode = 'local' + with unittest.mock.patch.object(task.one.alyx, 'rest') as mock_rest: out = task.register_fov(meta, None) mock_rest.assert_not_called() self.assertEqual(1, len(out)) @@ -206,3 +241,10 @@ def test_register_fov(self): locations = out[0]['location'] self.assertEqual(1, len(locations)) self.assertEqual('L', locations[0].get('provenance', 'L')) + + def tearDown(self) -> None: + """ + The ONE function is cached and therefore the One object persists beyond this test. + Here we return the mode back to the default after testing behaviour in offline mode. + """ + self.one.mode = 'auto' From 6e591014d3c7b6cb0fdbbe112fee3c4c5a6ee697 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 11 Oct 2023 15:45:04 +0100 Subject: [PATCH 25/68] change nprocess --- ibllib/io/extractors/camera.py | 2 +- ibllib/io/extractors/video_motion.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index bf3c95528..5aa353adf 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -157,7 +157,7 @@ def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio', # Can only use wheel alignment for left and right cameras raise ValueError(f'Wheel alignment not supported for {self.label} camera') - motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, upload=True) + motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, sync='nidq', upload=True) new_times = motion_class.process() if not motion_class.qc_outcome: raise ValueError(f'Wheel alignment failed to pass qc: {motion_class.qc}') diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index 14de3b3f7..7afda4c87 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -1053,8 +1053,8 @@ def process(self): wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), int(self.camera_meta['fps'] * toverlap)) - out = Parallel(n_jobs=self.nprocess)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) - for iw, (first, last) in enumerate(wg.firstlast)) + out = Parallel(n_jobs=1)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) + for iw, (first, last) in enumerate(wg.firstlast)) self.shifts = np.array([]) self.t_shifts = np.array([]) From c9d71270de4cde47929723f3e2aeccdf458240c2 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Thu, 12 Oct 2023 10:11:26 +0100 Subject: [PATCH 26/68] better logging and release notes --- ibllib/io/extractors/camera.py | 6 +++--- release_notes.md | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index 5aa353adf..7612c3e9e 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -160,13 +160,13 @@ def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio', motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, sync='nidq', upload=True) new_times = motion_class.process() if not motion_class.qc_outcome: - raise ValueError(f'Wheel alignment failed to pass qc: {motion_class.qc}') + raise ValueError(f'Wheel alignment for {self.label} camera failed to pass qc: {motion_class.qc}') else: - _logger.warning(f'Wheel alignment successful, qc: {motion_class.qc}') + _logger.warning(f'Wheel alignment for {self.label} camera successful, qc: {motion_class.qc}') return new_times except Exception as err: - _logger.critical(f'Failed to align with wheel: {err}') + _logger.critical(f'Failed to align with wheel for {self.label} camera: {err}') if length < raw_ts.size: df = raw_ts.size - length diff --git a/release_notes.md b/release_notes.md index beeb99179..a7aab07df 100644 --- a/release_notes.md +++ b/release_notes.md @@ -1,3 +1,7 @@ +## Develop +- Add full video wheel motion alignment code to ibllib.io.extractors.video_motion module +- Change FPGA camera extractor to attempt wheel alignment if audio alignment fails + ## Release Notes 2.26 ### features From 5f97c0b9307c3bf6f74c0804687433c3d88bcb6e Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 17 Oct 2023 16:11:34 +0300 Subject: [PATCH 27/68] Move url2uri to ONE --- ibllib/oneibl/patcher.py | 9 +-------- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/ibllib/oneibl/patcher.py b/ibllib/oneibl/patcher.py index 290322f10..d34691710 100644 --- a/ibllib/oneibl/patcher.py +++ b/ibllib/oneibl/patcher.py @@ -4,7 +4,6 @@ from collections import defaultdict from itertools import starmap from subprocess import Popen, PIPE, STDOUT -from urllib.parse import urlparse import subprocess import logging from getpass import getpass @@ -18,6 +17,7 @@ from one.webclient import AlyxClient from one.converters import path_from_dataset from one.remote import globus +from one.remote.aws import url2uri from ibllib.oneibl.registration import register_dataset @@ -34,13 +34,6 @@ SDSC_PATCH_PATH = PurePosixPath('/home/datauser/temp') -def url2uri(data_path): - parsed = urlparse(data_path) - assert parsed.netloc and parsed.scheme and parsed.path - bucket_name = parsed.netloc.split('.')[0] - return f's3://{bucket_name}{parsed.path}' - - def _run_command(cmd, dry=True): _logger.info(cmd) if dry: diff --git a/requirements.txt b/requirements.txt index c8adceff8..43d0760de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ tqdm>=4.32.1 ibl-neuropixel>=0.8.1 iblutil>=1.7.0 labcams # widefield extractor -ONE-api>=2.3 +ONE-api>=2.4 slidingRP>=1.0.0 # steinmetz lab refractory period metrics wfield==0.3.7 # widefield extractor frozen for now (2023/07/15) until Joao fixes latest version psychofit From eab63da38483d6e9dc5cc9167b5240b39ed9b9f0 Mon Sep 17 00:00:00 2001 From: k1o0 Date: Wed, 18 Oct 2023 16:02:57 +0300 Subject: [PATCH 28/68] mpciROIs.uuids (#663) * mpciROIs.uuids * Added tests --- brainbox/behavior/training.py | 5 +++-- ibllib/io/extractors/mesoscope.py | 1 + ibllib/pipes/mesoscope_tasks.py | 6 ++++++ ibllib/tests/extractors/test_ephys_trials.py | 4 ++++ ibllib/tests/test_mesoscope.py | 20 ++++++++++++++++++++ 5 files changed, 34 insertions(+), 2 deletions(-) diff --git a/brainbox/behavior/training.py b/brainbox/behavior/training.py index 4a247d819..2e6c9f9fd 100644 --- a/brainbox/behavior/training.py +++ b/brainbox/behavior/training.py @@ -83,7 +83,8 @@ class TrainingStatus(IntFlag): ... assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained' ... assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY - # Get the next training status + Get the next training status + >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()]) @@ -91,7 +92,7 @@ class TrainingStatus(IntFlag): ----- - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b. - A subject may acheive both TRAINED_1A and TRAINED_1B within a single session, therefore it - is possible to have skipped the TRAINED_1A session status. + is possible to have skipped the TRAINED_1A session status. """ UNTRAINABLE = auto() UNBIASABLE = auto() diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 78ed21674..4def5ed3a 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -280,6 +280,7 @@ def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding= moves = extract_wheel_moves(wheel['timestamps'], wheel['position']) if display: + assert self.bpod_trials, 'no bpod trials to compare' fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True) bpod_ts = self.bpod_trials['wheel_timestamps'] bpod_pos = self.bpod_trials['wheel_position'] diff --git a/ibllib/pipes/mesoscope_tasks.py b/ibllib/pipes/mesoscope_tasks.py index fee1a9c4a..683395431 100644 --- a/ibllib/pipes/mesoscope_tasks.py +++ b/ibllib/pipes/mesoscope_tasks.py @@ -216,6 +216,7 @@ def signature(self): ('mpciROIs.stackPos.npy', 'alf/FOV*', True), ('mpciROIs.mpciROITypes.npy', 'alf/FOV*', True), ('mpciROIs.cellClassifier.npy', 'alf/FOV*', True), + ('mpciROIs.uuids.csv', 'alf/FOV*', True), ('mpciROITypes.names.tsv', 'alf/FOV*', True), ('mpciROIs.masks.npy', 'alf/FOV*', True), ('mpciROIs.neuropilMasks.npy', 'alf/FOV*', True), @@ -328,6 +329,11 @@ def _rename_outputs(self, suite2p_dir, frameQC_names, frameQC, rename_dict=None) np.save(fov_dir.joinpath('mpciROIs.stackPos.npy'), np.asarray([(*s['med'], 0) for s in stat], dtype=int)) np.save(fov_dir.joinpath('mpciROIs.cellClassifier.npy'), np.asarray(iscell[:, 1], dtype=float)) np.save(fov_dir.joinpath('mpciROIs.mpciROITypes.npy'), np.asarray(iscell[:, 0], dtype=np.int16)) + # clusters uuids + uuid_list = ['uuids'] + list(map(str, [uuid.uuid4() for _ in range(len(iscell))])) + with open(fov_dir.joinpath('mpciROIs.uuids.csv'), 'w+') as fid: + fid.write('\n'.join(uuid_list)) + pd.DataFrame([(0, 'no cell'), (1, 'cell')], columns=['roi_values', 'roi_labels'] ).to_csv(fov_dir.joinpath('mpciROITypes.names.tsv'), sep='\t', index=False) # ROI and neuropil masks diff --git a/ibllib/tests/extractors/test_ephys_trials.py b/ibllib/tests/extractors/test_ephys_trials.py index ba49d31bb..d5483792f 100644 --- a/ibllib/tests/extractors/test_ephys_trials.py +++ b/ibllib/tests/extractors/test_ephys_trials.py @@ -90,6 +90,10 @@ def test_align_to_trial(self): desired_out = np.array([4, 13, np.nan, 33, np.nan]) self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) + # test errors + self.assertRaises(ValueError, ephys_fpga._assign_events_to_trial, np.array([0., 2., 1.]), t_event) + self.assertRaises(ValueError, ephys_fpga._assign_events_to_trial, t_trial_start, np.array([0., 2., 1.])) + def test_wheel_trace_from_sync(self): pos_ = - np.array([-1, 0, -1, -2, -1, -2]) * (np.pi / ephys_fpga.WHEEL_TICKS) ta = np.array([1, 2, 3, 4, 5, 6]) diff --git a/ibllib/tests/test_mesoscope.py b/ibllib/tests/test_mesoscope.py index b828386db..7ae3e5cd4 100644 --- a/ibllib/tests/test_mesoscope.py +++ b/ibllib/tests/test_mesoscope.py @@ -4,6 +4,7 @@ from unittest import mock import tempfile import json +from itertools import chain from pathlib import Path from one.api import ONE @@ -11,6 +12,7 @@ from ibllib.pipes.mesoscope_tasks import MesoscopePreprocess, MesoscopeFOV, \ find_triangle, surface_normal, _nearest_neighbour_1d +from ibllib.io.extractors import mesoscope from ibllib.tests import TEST_DB # Mock suit2p which is imported in MesoscopePreprocess @@ -248,3 +250,21 @@ def tearDown(self) -> None: Here we return the mode back to the default after testing behaviour in offline mode. """ self.one.mode = 'auto' + + +class TestImagingMeta(unittest.TestCase): + """Test raw imaging metadata versioning.""" + def test_patch_imaging_meta(self): + """Test for ibllib.io.extractors.mesoscope.patch_imaging_meta function.""" + meta = {'version': '0.1.0', 'FOV': [{'roiUuid': None}, {'roiUUID': None}]} + new_meta = mesoscope.patch_imaging_meta(meta) + self.assertEqual(set(chain(*map(dict.keys, new_meta['FOV']))), {'roiUUID'}) + meta = {'FOV': [ + dict.fromkeys(['topLeftDeg', 'topRightDeg', 'bottomLeftDeg', 'bottomRightDeg']), + dict.fromkeys(['topLeftMM', 'topRightMM', 'bottomLeftMM', 'bottomRightMM']) + ]} + new_meta = mesoscope.patch_imaging_meta(meta) + self.assertIn('channelSaved', new_meta) + self.assertCountEqual(new_meta['FOV'][0], ('Deg', 'MM')) + expected = ('topLeft', 'topRight', 'bottomLeft', 'bottomRight') + self.assertCountEqual(new_meta['FOV'][0]['MM'], expected) From f8b7ad216e18b093f7c16662ed117961a45e7652 Mon Sep 17 00:00:00 2001 From: Gaelle Date: Wed, 18 Oct 2023 15:05:39 +0200 Subject: [PATCH 29/68] raw audio spectrogram example --- .../loading_data/loading_raw_audio_data.ipynb | 168 ++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 examples/loading_data/loading_raw_audio_data.ipynb diff --git a/examples/loading_data/loading_raw_audio_data.ipynb b/examples/loading_data/loading_raw_audio_data.ipynb new file mode 100644 index 000000000..cbd88aa03 --- /dev/null +++ b/examples/loading_data/loading_raw_audio_data.ipynb @@ -0,0 +1,168 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5683982d", + "metadata": {}, + "source": [ + "# Loading Raw Audio Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b2485da", + "metadata": { + "nbsphinx": "hidden" + }, + "outputs": [], + "source": [ + "# Turn off logging, this is a hidden cell on docs page\n", + "import logging\n", + "logger = logging.getLogger('ibllib')\n", + "logger.setLevel(logging.CRITICAL)" + ] + }, + { + "cell_type": "markdown", + "id": "16345774", + "metadata": {}, + "source": [ + "The audio file is saved from the microphone. It is useful to look at it to plot a spectrogram and confirm the sounds played during the task are indeed audible." + ] + }, + { + "cell_type": "markdown", + "id": "8d62c890", + "metadata": {}, + "source": [ + "## Relevant datasets\n", + "* _iblrig_micData.raw.flac\n" + ] + }, + { + "cell_type": "markdown", + "id": "bc23fdf7", + "metadata": {}, + "source": [ + "## Loading" + ] + }, + { + "cell_type": "markdown", + "id": "9103084d", + "metadata": {}, + "source": [ + "### Loading raw audio file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b807296", + "metadata": { + "ibl_execute": false + }, + "outputs": [], + "source": [ + "from one.api import ONE\n", + "import soundfile as sf\n", + "\n", + "one = ONE()\n", + "eid = '4ecb5d24-f5cc-402c-be28-9d0f7cb14b3a'\n", + "\n", + "# -- Get raw data\n", + "filename = one.load_dataset(eid, '_iblrig_micData.raw.flac', download_only=True)\n", + "with open(filename, 'rb') as f:\n", + " wav, fs = sf.read(f)" + ] + }, + { + "cell_type": "markdown", + "id": "203d23c1", + "metadata": {}, + "source": [ + "## Plot the spectrogram" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "811e3533", + "metadata": { + "ibl_execute": false + }, + "outputs": [], + "source": [ + "from ibllib.io.extractors.training_audio import welchogram\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# -- Compute spectrogram over first 2 minutes\n", + "t_idx = 120 * fs\n", + "tscale, fscale, W, detect = welchogram(fs, wav[:t_idx])\n", + "\n", + "# -- Put data into single variable\n", + "TF = {}\n", + "\n", + "TF['power'] = W.astype(np.single)\n", + "TF['frequencies'] = fscale[None, :].astype(np.single)\n", + "TF['onset_times'] = detect\n", + "TF['times_mic'] = tscale[:, None].astype(np.single)\n", + "\n", + "# # -- Plot spectrogram\n", + "tlims = TF['times_mic'][[0, -1]].flatten()\n", + "flims = TF['frequencies'][0, [0, -1]].flatten()\n", + "fig = plt.figure(figsize=[16, 7])\n", + "ax = plt.axes()\n", + "im = ax.imshow(20 * np.log10(TF['power'].T), aspect='auto', cmap=plt.get_cmap('magma'),\n", + " extent=np.concatenate((tlims, flims)),\n", + " origin='lower')\n", + "ax.set_xlabel(r'Time (s)')\n", + "ax.set_ylabel(r'Frequency (Hz)')\n", + "plt.colorbar(im)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "bef6702e", + "metadata": {}, + "source": [ + "## More details\n", + "* [Description of audio datasets](https://docs.google.com/document/d/1OqIqqakPakHXRAwceYLwFY9gOrm8_P62XIfCTnHwstg/edit#heading=h.n61f0vdcplxp)" + ] + }, + { + "cell_type": "markdown", + "id": "4e9dd4b9", + "metadata": {}, + "source": [ + "## Useful modules\n", + "* [ibllib.io.extractors.training_audio](https://int-brain-lab.github.io/iblenv/_autosummary/ibllib.io.extractors.training_audio.html#module-ibllib.io.extractors.training_audio)" + ] + } + ], + "metadata": { + "celltoolbar": "Edit Metadata", + "kernelspec": { + "display_name": "Python [conda env:iblenv] *", + "language": "python", + "name": "conda-env-iblenv-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 183a68d7dee9d9e730c137f3090705f0621a1ede Mon Sep 17 00:00:00 2001 From: Gaelle Date: Wed, 18 Oct 2023 15:58:56 +0200 Subject: [PATCH 30/68] fix from brainbox.io.one import load_channel_locations --- examples/loading_data/loading_raw_ephys_data.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/loading_data/loading_raw_ephys_data.ipynb b/examples/loading_data/loading_raw_ephys_data.ipynb index f8fd8ed37..42cfb3517 100644 --- a/examples/loading_data/loading_raw_ephys_data.ipynb +++ b/examples/loading_data/loading_raw_ephys_data.ipynb @@ -349,8 +349,8 @@ "execution_count": null, "outputs": [], "source": [ - "import brainbox\n", - "channels = brainbox.io.one.load_channel_locations(eid, probe)\n", + "from brainbox.io.one import load_channel_locations\n", + "channels = load_channel_locations(eid, probe)\n", "channels[probe][\"localCoordinates\"]" ], "metadata": { From 21233e34177d3863f28167960abebf6011c329b4 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Thu, 19 Oct 2023 15:09:23 +0100 Subject: [PATCH 31/68] fix the docs --- .../loading_data/loading_raw_ephys_data.ipynb | 52 ++++++++++--------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/examples/loading_data/loading_raw_ephys_data.ipynb b/examples/loading_data/loading_raw_ephys_data.ipynb index 42cfb3517..9394b675e 100644 --- a/examples/loading_data/loading_raw_ephys_data.ipynb +++ b/examples/loading_data/loading_raw_ephys_data.ipynb @@ -328,55 +328,57 @@ }, { "cell_type": "markdown", + "id": "723df072", + "metadata": {}, "source": [ "## Get the probe geometry" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "markdown", + "id": "35a0db60", + "metadata": {}, "source": [ "### Using the `eid` and `probe` information" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "id": "c2335beb", + "metadata": {}, "outputs": [], "source": [ "from brainbox.io.one import load_channel_locations\n", - "channels = load_channel_locations(eid, probe)\n", - "channels[probe][\"localCoordinates\"]" - ], - "metadata": { - "collapsed": false - } + "from one.api import ONE\n", + "import spikeglx\n", + "import numpy as np\n", + "one = ONE()\n", + "\n", + "pid = 'da8dfec1-d265-44e8-84ce-6ae9c109b8bd'\n", + "eid, probe = one.pid2eid(pid)\n", + "channels = load_channel_locations(eid, probe)[probe]\n", + "channel_geometry = np.c_[channels['axial_um'], channels['lateral_um']]" + ] }, { "cell_type": "markdown", + "id": "360f668c", + "metadata": {}, "source": [ "### Using the reader and the `.cbin` file" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "id": "763d6ac8", + "metadata": {}, "outputs": [], "source": [ "sr = spikeglx.Reader(bin_file)\n", "sr.geometry" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "markdown", @@ -395,9 +397,9 @@ "metadata": { "celltoolbar": "Edit Metadata", "kernelspec": { - "display_name": "Python [conda env:iblenv] *", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "conda-env-iblenv-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -409,7 +411,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.9.16" } }, "nbformat": 4, From e761f676ef53b111730f6a3c705f26e05935f512 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 23 Oct 2023 15:09:36 +0300 Subject: [PATCH 32/68] Test GlobusPatcher --- ibllib/io/extractors/ephys_fpga.py | 11 +++- ibllib/io/extractors/mesoscope.py | 8 +-- ibllib/oneibl/data_handlers.py | 13 ++++- ibllib/oneibl/patcher.py | 84 +++++++++++++++++++----------- ibllib/oneibl/registration.py | 4 +- ibllib/pipes/__init__.py | 29 ++++++++--- ibllib/pipes/behavior_tasks.py | 1 + ibllib/qc/critical_reasons.py | 1 + ibllib/tests/test_oneibl.py | 83 ++++++++++++++++++++++++++++- 9 files changed, 188 insertions(+), 46 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 74ac1e551..eff8c9ab6 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -821,13 +821,20 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', self.bpod_trials.update(trials_table) self.bpod_trials['intervals_bpod'] = np.copy(self.bpod_trials['intervals']) + bpod = get_sync_fronts(sync, chmap['bpod']) # Get the spacer times for this protocol if (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer # The spacers are TTLs generated by Bpod at the start of each protocol - bpod = get_sync_fronts(sync, chmap['bpod']) tmin, tmax = get_protocol_period(self.session_path, protocol_number, bpod) else: - tmin = tmax = None + # Older sessions don't have protocol spacers so we sync the Bpod intervals here to + # find the approximate end time of the protocol (this will exclude the passive signals + # in ephysChoiceWorld that tend to ruin the final trial extraction). + t_trial_start, *_ = _assign_events_bpod(bpod['times'], bpod['polarities']) + bpod_start = self.bpod_trials['intervals_bpod'][:, 0] + fcn, *_ = neurodsp.utils.sync_timestamps(bpod_start, t_trial_start) + tmin = fcn(trials_table['intervals'][0, 0]) - 1 + tmax = fcn(trials_table['intervals'][-1, 1]) + 1 # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync( diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 561bb6343..57709032d 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -111,7 +111,6 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa # If no protocol number is defined, trim timestamps based on Bpod trials intervals trials_table = trials['table'] - bpod = get_sync_fronts(sync, chmap['bpod']) if kwargs.get('protocol_number') is None: tmin = trials_table.intervals_0.iloc[0] - 1 tmax = trials_table.intervals_1.iloc[-1] @@ -122,15 +121,16 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa mask = np.logical_and(trials['wheelMoves_intervals'][:, 0] >= tmin, trials['wheelMoves_intervals'][:, 0] <= tmax) trials['wheelMoves_intervals'] = trials['wheelMoves_intervals'][mask, :] else: + bpod = get_sync_fronts(sync, chmap['bpod']) tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod) - bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax) + self.bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax) self.frame2ttl = get_sync_fronts(sync, chmap['frame2ttl'], tmin, tmax) # save for later access by QC # Replace valve open times with those extracted from the DAQ # TODO Let's look at the expected open length based on calibration and reward volume - assert len(bpod['times']) > 0, 'No Bpod TTLs detected on DAQ' - _, driver_out, _, = _assign_events_bpod(bpod['times'], bpod['polarities'], False) + assert len(self.bpod['times']) > 0, 'No Bpod TTLs detected on DAQ' + _, driver_out, _, = _assign_events_bpod(self.bpod['times'], self.bpod['polarities'], False) # Use the driver TTLs to find the valve open times that correspond to the valve opening valve_open_times = self.get_valve_open_times(driver_ttls=driver_out) assert len(valve_open_times) == sum(trials_table.feedbackType == 1) # TODO Relax assertion diff --git a/ibllib/oneibl/data_handlers.py b/ibllib/oneibl/data_handlers.py index 6df9453a0..9616c89ae 100644 --- a/ibllib/oneibl/data_handlers.py +++ b/ibllib/oneibl/data_handlers.py @@ -1,3 +1,9 @@ +"""Downloading of task dependent datasets and registration of task output datasets. + +The DataHandler class is used by the pipes.tasks.Task class to ensure dependent datasets are +present and to register and upload the output datasets. For examples on how to run a task using +specific data handlers, see :func:`ibllib.pipes.tasks`. +""" import logging import pandas as pd from pathlib import Path @@ -231,7 +237,10 @@ def uploadData(self, outputs, version, **kwargs): class RemoteAwsDataHandler(DataHandler): def __init__(self, task, session_path, signature, one=None): """ - Data handler for running tasks on remote compute node. Will download missing data from private ibl s3 AWS data bucket + Data handler for running tasks on remote compute node. + + This will download missing data from the private IBL S3 AWS data bucket. New datasets are + uploaded via Globus. :param session_path: path to session :param signature: input and output file signatures @@ -330,7 +339,7 @@ def cleanUp(self): class RemoteGlobusDataHandler(DataHandler): """ - Data handler for running tasks on remote compute node. Will download missing data using globus + Data handler for running tasks on remote compute node. Will download missing data using Globus. :param session_path: path to session :param signature: input and output file signatures diff --git a/ibllib/oneibl/patcher.py b/ibllib/oneibl/patcher.py index d34691710..5cbb1689d 100644 --- a/ibllib/oneibl/patcher.py +++ b/ibllib/oneibl/patcher.py @@ -1,3 +1,25 @@ +"""A module for ad-hoc dataset modification and registration. + +Unlike the DataHandler class in oneibl.data_handlers, the Patcher class allows one to fully remove +datasets (delete them from the database and repositories), and to overwrite datasets on both the +main repositories and the local servers. Additionally the Patcher can handle datasets from +multiple sessions at once. + +Examples +-------- +Delete a dataset from Alyx and all associated repositories. + +>>> dataset_id = 'f4aafe6c-a7ab-4390-82cd-2c0e245322a5' +>>> task_ids, files_by_repo = IBLGlobusPatcher(AlyxClient(), 'admin').delete_dataset(dataset_id) + +Patch some local datasets using Globus + +>>> from one.api import ONE +>>> patcher = GlobusPatcher('admin', ONE(), label='UCLA audio times patch') +>>> responses = patcher.patch_datasets(file_paths) # register the new datasets to Alyx +>>> patcher.launch_transfers(local_servers=True) # transfer to all remote repositories + +""" import abc import ftplib from pathlib import Path, PurePosixPath, WindowsPath @@ -18,6 +40,7 @@ from one.converters import path_from_dataset from one.remote import globus from one.remote.aws import url2uri +from one.util import ensure_list from ibllib.oneibl.registration import register_dataset @@ -58,7 +81,7 @@ def sdsc_path_from_dataset(dset, root_path=SDSC_ROOT_PATH): """ Returns sdsc file path from a dset record or a list of dsets records from REST :param dset: dset dictionary or list of dictionaries from ALyx rest endpoint - :param root_path: (optional) the prefix path such as one download directory or sdsc root + :param root_path: (optional) the prefix path such as one download directory or SDSC root """ return path_from_dataset(dset, root_path=root_path, uuid=True) @@ -68,7 +91,7 @@ def globus_path_from_dataset(dset, repository=None, uuid=False): Returns local one file path from a dset record or a list of dsets records from REST :param dset: dset dictionary or list of dictionaries from ALyx rest endpoint :param repository: (optional) repository name of the file record (if None, will take - the first filerecord with an URL) + the first filerecord with a URL) """ return path_from_dataset(dset, root_path=PurePosixPath('/'), repository=repository, uuid=uuid) @@ -91,7 +114,7 @@ def _patch_dataset(self, path, dset_id=None, dry=False, ftp=False): assert dset_id assert is_uuid_string(dset_id) assert path.exists() - dset = self.one.alyx.rest('datasets', "read", id=dset_id) + dset = self.one.alyx.rest('datasets', 'read', id=dset_id) fr = next(fr for fr in dset['file_records'] if 'flatiron' in fr['data_repository']) remote_path = Path(fr['data_repository_path']).joinpath(fr['relative_path']) remote_path = add_uuid_string(remote_path, dset_id).as_posix() @@ -134,7 +157,7 @@ def register_datasets(self, file_list, **kwargs): nses = len(register_dict) for i, label in enumerate(register_dict): _files = register_dict[label]['files'] - _logger.info(f"{i}/{nses} {label}, registering {len(_files)} files") + _logger.info(f"{i + 1}/{nses} {label}, registering {len(_files)} files") responses.append(self.register_dataset(_files, **kwargs)) return responses @@ -157,7 +180,7 @@ def patch_dataset(self, file_list, dry=False, ftp=False, **kwargs): file_list = [Path(file_list)] assert len(set([get_session_path(f) for f in file_list])) == 1 assert all([Path(f).exists() for f in file_list]) - response = self.register_dataset(file_list, dry=dry, **kwargs) + response = ensure_list(self.register_dataset(file_list, dry=dry, **kwargs)) if dry: return # from the dataset info, set flatIron flag to exists=True @@ -182,7 +205,7 @@ def patch_datasets(self, file_list, **kwargs): nses = len(register_dict) for i, label in enumerate(register_dict): _files = register_dict[label]['files'] - _logger.info(f"{i}/{nses} {label}, registering {len(_files)} files") + _logger.info(f'{i + 1}/{nses} {label}, registering {len(_files)} files') responses.extend(self.patch_dataset(_files, **kwargs)) return responses @@ -195,27 +218,28 @@ def _rm(self, *args, **kwargs): pass -class GlobusPatcher(Patcher): +class GlobusPatcher(Patcher, globus.Globus): """ Requires GLOBUS keys access """ def __init__(self, client_name='default', one=None, label='ibllib patch'): - assert one - self.globus = globus.Globus(client_name) + assert one and not one.offline + Patcher.__init__(self, one=one) + globus.Globus.__init__(self, client_name) self.label = label # get a dictionary of data repositories from Alyx (with globus ids) - self.globus.fetch_endpoints_from_alyx(one.alyx) - flatiron_id = self.globus.endpoints['flatiron_cortexlab']['id'] - if 'flatiron' not in self.globus.endpoints: - self.globus.add_endpoint(flatiron_id, 'flatiron', root_path='/') - self.globus.endpoints['flatiron'] = self.globus.endpoints['flatiron_cortexlab'] + self.fetch_endpoints_from_alyx(one.alyx) + flatiron_id = self.endpoints['flatiron_cortexlab']['id'] + if 'flatiron' not in self.endpoints: + self.add_endpoint(flatiron_id, 'flatiron', root_path='/') + self.endpoints['flatiron'] = self.endpoints['flatiron_cortexlab'] # transfers/delete from the current computer to the flatiron: mandatory and executed first - local_id = self.globus.endpoints['local']['id'] + local_id = self.endpoints['local']['id'] self.globus_transfer = globus_sdk.TransferData( - self.globus.client, local_id, flatiron_id, verify_checksum=True, sync_level='checksum', label=label) - self.globus_delete = globus_sdk.DeleteData(self.globus.client, flatiron_id, label=label) + self.client, local_id, flatiron_id, verify_checksum=True, sync_level='checksum', label=label) + self.globus_delete = globus_sdk.DeleteData(self.client, flatiron_id, label=label) # transfers/delete from flatiron to optional third parties to synchronize / delete self.globus_transfers_locals = {} self.globus_deletes_locals = {} @@ -227,18 +251,18 @@ def _scp(self, local_path, remote_path, dry=True): ) _logger.info(f"Globus copy {local_path} to {remote_path}") if not dry: - if isinstance(self.globus_transfer, globus_sdk.transfer.data.TransferData): - self.globus_transfer.add_item(local_path, remote_path) + if isinstance(self.globus_transfer, globus_sdk.TransferData): + self.globus_transfer.add_item(local_path, remote_path.as_posix()) else: self.globus_transfer.path_src.append(local_path) - self.globus_transfer.path_dest.append(remote_path) + self.globus_transfer.path_dest.append(remote_path.as_posix()) return 0, '' def _rm(self, flatiron_path, dry=True): flatiron_path = Path('/').joinpath(flatiron_path.relative_to(Path(FLATIRON_MOUNT))) _logger.info(f'Globus del {flatiron_path}') if not dry: - if isinstance(self.globus_delete, globus_sdk.transfer.data.DeleteData): + if isinstance(self.globus_delete, globus_sdk.DeleteData): self.globus_delete.add_item(flatiron_path) else: self.globus_delete.path.append(flatiron_path) @@ -258,23 +282,25 @@ def patch_datasets(self, file_list, **kwargs): # get the flatiron path fr = next(fr for fr in dset['file_records'] if 'flatiron' in fr['data_repository']) relative_path = add_uuid_string(fr['relative_path'], dset['id']).as_posix() - flatiron_path = self.globus.to_address(relative_path, fr['data_repository']) + flatiron_path = self.to_address(relative_path, fr['data_repository']) # loop over the remaining repositories (local servers) and create a transfer # from flatiron to the local server for fr in dset['file_records']: if fr['data_repository'] == DMZ_REPOSITORY: continue - repo_gid = self.globus.endpoints[fr['data_repository']]['id'] - flatiron_id = self.globus.endpoints['flatiron']['id'] + if fr['data_repository'] not in self.endpoints: + continue + repo_gid = self.endpoints[fr['data_repository']]['id'] + flatiron_id = self.endpoints['flatiron']['id'] if repo_gid == flatiron_id: continue # if there is no transfer already created, initialize it if repo_gid not in self.globus_transfers_locals: self.globus_transfers_locals[repo_gid] = globus_sdk.TransferData( - self.globus.client, flatiron_id, repo_gid, verify_checksum=True, + self.client, flatiron_id, repo_gid, verify_checksum=True, sync_level='checksum', label=f"{self.label} on {fr['data_repository']}") # get the local server path and create the transfer item - local_server_path = self.globus.to_address(fr['relative_path'], fr['data_repository']) + local_server_path = self.to_address(fr['relative_path'], fr['data_repository']) self.globus_transfers_locals[repo_gid].add_item(flatiron_path, local_server_path) return responses @@ -285,7 +311,7 @@ def launch_transfers(self, local_servers=False): :param: local_servers (False): if True, sync the local servers after the main transfer :return: None """ - gtc = self.globus.client + gtc = self.client def _wait_for_task(resp): # patcher.transfer_client.get_task(task_id='364fbdd2-4deb-11eb-8ffb-0a34088e79f9') @@ -339,11 +365,11 @@ def launch_transfers_secondary(self): for lt in self.globus_transfers_locals: transfer = self.globus_transfers_locals[lt] if len(transfer['DATA']) > 0: - self.globus.client.submit_transfer(transfer) + self.client.submit_transfer(transfer) for ld in self.globus_deletes_locals: delete = self.globus_deletes_locals[ld] if len(transfer['DATA']) > 0: - self.globus.client.submit_delete(delete) + self.client.submit_delete(delete) class IBLGlobusPatcher(Patcher, globus.Globus): diff --git a/ibllib/oneibl/registration.py b/ibllib/oneibl/registration.py index 0996f01e0..62d4ee4d0 100644 --- a/ibllib/oneibl/registration.py +++ b/ibllib/oneibl/registration.py @@ -4,7 +4,7 @@ import logging import itertools -from pkg_resources import parse_version +from packaging import version from one.alf.files import get_session_path, folder_parts, get_alf_path from one.registration import RegistrationClient, get_dataset_type from one.remote.globus import get_local_endpoint_id, get_lab_from_endpoint_id @@ -351,7 +351,7 @@ def _alyx_procedure_from_task_type(task_type): def rename_files_compatibility(ses_path, version_tag): if not version_tag: return - if parse_version(version_tag) <= parse_version('3.2.3'): + if version.parse(version_tag) <= version.parse('3.2.3'): task_code = ses_path.glob('**/_ibl_trials.iti_duration.npy') for fn in task_code: fn.replace(fn.parent.joinpath('_ibl_trials.itiDuration.npy')) diff --git a/ibllib/pipes/__init__.py b/ibllib/pipes/__init__.py index 2b68cdb04..1f405a7c3 100644 --- a/ibllib/pipes/__init__.py +++ b/ibllib/pipes/__init__.py @@ -7,17 +7,34 @@ def assign_task(task_deck, session_path, task, **kwargs): """ + Assigns a task to a task deck with the task name as key. + + This is a convenience function when creating a large task deck. Parameters ---------- - task_deck : - session_path - task - kwargs + task_deck : dict + A dictionary of tasks to add to. + session_path : str, pathlib.Path + A session path to pass to the task. + task : ibllib.pipes.tasks.Task + A task class to instantiate and assign. + **kwargs + Optional keyword arguments to pass to the task. + + Examples + -------- + >>> from ibllib.pipes.video_tasks import VideoCompress + >>> task_deck = {} + >>> session_path = './subject/2023-01-01/001' + >>> assign_task(task_deck, session_path, VideoCompress, cameras=('left',)) + {'VideoCompress': } - Returns - ------- + Using partial for convenience + >>> from functools import partial + >>> assign = partial(assign_task, task_deck, session_path) + >>> assign(VideoCompress, cameras=('left',)) """ t = task(session_path, **kwargs) task_deck[t.name] = t diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py index 6f1c8d506..e1026573d 100644 --- a/ibllib/pipes/behavior_tasks.py +++ b/ibllib/pipes/behavior_tasks.py @@ -378,6 +378,7 @@ def _run_qc(self, trials_data=None, update=False, plot_qc=False): qc_extractor.wheel_encoding = 'X4' qc_extractor.frame_ttls = self.extractor.frame2ttl qc_extractor.audio_ttls = self.extractor.audio + qc_extractor.bpod_ttls = self.extractor.bpod # used only in iblapps task QC viewer qc.extractor = qc_extractor # Aggregate and update Alyx QC fields diff --git a/ibllib/qc/critical_reasons.py b/ibllib/qc/critical_reasons.py index bf50bda20..d964abd53 100644 --- a/ibllib/qc/critical_reasons.py +++ b/ibllib/qc/critical_reasons.py @@ -563,6 +563,7 @@ class TaskSignOffNote(SignOffNote): 'raw trial data does not exist', 'wheel data corrupt', 'task data could not be synced', + 'stimulus timings unreliable' ] diff --git a/ibllib/tests/test_oneibl.py b/ibllib/tests/test_oneibl.py index aa9483d6b..180b14027 100644 --- a/ibllib/tests/test_oneibl.py +++ b/ibllib/tests/test_oneibl.py @@ -1,11 +1,13 @@ import unittest -import tempfile from unittest import mock +import tempfile from pathlib import PurePosixPath, Path import json import datetime import random import string +import uuid +from itertools import chain from requests import HTTPError import numpy as np @@ -78,6 +80,85 @@ def mock_input(prompt): return FTP_pars[next(k for k in FTP_pars.keys() if k in prompt.replace(',', '').split())] +class TestGlobusPatcher(unittest.TestCase): + """Tests for the ibllib.oneibl.patcher.GlobusPatcher class.""" + + globus_sdk_mock = None + """unittest.mock._patch: Mock object for globus_sdk package.""" + + @mock.patch('one.remote.globus._setup') + def setUp(self, _) -> None: + # Create a temp dir for writing datasets to + self.tempdir = tempfile.TemporaryDirectory() + # The github CI root dir contains an alias/symlink so we must resolve it + self.root_path = Path(self.tempdir.name).resolve() + self.addCleanup(self.tempdir.cleanup) + # Mock the Globus setup process so the parameters aren't overwritten + self.pars = iopar.from_dict({ + 'GLOBUS_CLIENT_ID': '123', + 'refresh_token': '456', + 'local_endpoint': str(uuid.uuid1()), + 'local_path': str(self.root_path), + 'access_token': 'abc', + 'expires_at_seconds': datetime.datetime.now().timestamp() + 60**2 + }) + # Mock the globus SDK so that no actual tasks are submitted + self.globus_sdk_mock = mock.patch('one.remote.globus.globus_sdk') + self.globus_sdk_mock.start() + self.addCleanup(self.globus_sdk_mock.stop) + self.one = ONE(**TEST_DB) + with mock.patch('one.remote.globus.load_client_params', return_value=self.pars): + self.globus_patcher = patcher.GlobusPatcher(one=self.one) + + def test_patch_datasets(self): + """Tests for GlobusPatcher.patch_datasets and GlobusPatcher.launch_transfers methods.""" + # Create a couple of datasets to patch + file_list = ['ZFM-01935/2021-02-05/001/alf/_ibl_wheelMoves.intervals.npy', + 'ZM_1743/2019-06-14/001/alf/_ibl_wheel.position.npy'] + dids = ['80fabd30-9dc8-4778-b349-d175af63e1bd', 'fede964f-55cd-4267-95e0-327454e68afb'] + # These exist on the test database, so get their info in order to mock registration response + for r in (responses := self.one.alyx.rest('datasets', 'list', django=f'pk__in,{dids}')): + r['id'] = r['url'].split('/')[-1] + assert len(responses) == 2, f'one or both datasets {dids} not on database' + # Create the files on disk + for file in (file_list := list(map(self.root_path.joinpath, file_list))): + file.parent.mkdir(exist_ok=True, parents=True) + file.touch() + + # Mock the post method of AlyxClient and assert that it was called during registration + with mock.patch.object(self.one.alyx, 'post') as rest_mock: + rest_mock.side_effect = responses + self.globus_patcher.patch_datasets(file_list) + self.assertEqual(rest_mock.call_count, 2) + for call, file in zip(rest_mock.call_args_list, file_list): + self.assertEqual(call.args[0], '/register-file') + path = file.relative_to(self.root_path).as_posix() + self.assertTrue(path.startswith(call.kwargs['data']['path'])) + self.assertTrue(path.endswith(call.kwargs['data']['filenames'][0])) + + # Check whether the globus transfers were updated + self.assertIsNotNone(self.globus_patcher.globus_transfer) + transfer_data = self.globus_patcher.globus_transfer['DATA'] + self.assertEqual(len(transfer_data), len(file_list)) + for data, file, did in zip(transfer_data, file_list, dids): + path = file.relative_to(self.root_path).as_posix() + self.assertTrue(data['source_path'].endswith(path)) + self.assertIn(did, data['destination_path'], 'failed to add UUID to destination file name') + + # Check added local server transfers + self.assertTrue(len(self.globus_patcher.globus_transfers_locals)) + transfer_data = list(chain(*[x['DATA'] for x in self.globus_patcher.globus_transfers_locals.values()])) + for data, file, did in zip(transfer_data, file_list, dids): + path = file.relative_to(self.root_path).as_posix() + self.assertEqual(data['destination_path'], '/mnt/s0/Data/Subjects/' + path) + self.assertIn(did, data['source_path'], 'failed to add UUID to source file name') + + # Check behaviour when tasks submitted + self.globus_patcher.client.get_task.return_value = {'completion_time': 0, 'fatal_error': None} + self.globus_patcher.launch_transfers(local_servers=True) + self.globus_patcher.client.submit_transfer.assert_called() + + class TestAlyx2Path(unittest.TestCase): dset = { 'url': 'https://alyx.internationalbrainlab.org/' From 154d4cd6c1d1b79082a79f51db82ef7e8154a636 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 23 Oct 2023 15:25:05 +0300 Subject: [PATCH 33/68] Moved Mayo's examples to tasks.py module --- ibllib/pipes/tasks.py | 72 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/ibllib/pipes/tasks.py b/ibllib/pipes/tasks.py index 25a645385..670e2e8fa 100644 --- a/ibllib/pipes/tasks.py +++ b/ibllib/pipes/tasks.py @@ -1,4 +1,74 @@ -"""The abstract Pipeline and Task superclasses and concrete task runner.""" +"""The abstract Pipeline and Task superclasses and concrete task runner. + +Examples +-------- + +1. Running a task on your local computer. +| Download: via ONE. +| Upload: N/A. + +>>> task = VideoSyncQcBpod(session_path, one=one, location='remote', sync='bpod') +>>> task.run() + +2. Running a task on the local server that belongs to a given subject (e.g SWC054 on floferlab). +| Download: all data expected to be present. +| Upload: normal way of registering datasets, filerecords created and bulk sync, bulk transfer + jobs on Alyx transfer the data. + +>>> from ibllib.pipes.video_tasks import VideoSyncQcBpod +>>> session_path = '/mnt/ibl/s0/Data/Subjects/SWC054/2023-01-01/001' +>>> task = VideoSyncQcBpod(session_path, one=one, sync='bpod') +>>> task.run() +>>> task.register_datasets(one=one, labs=get_lab(session_path, alyx=ONE().alyx)) + +3. Running a task on the local server that belongs to that subject and forcing redownload of +missing data. +| Download: via Globus (TODO we should change this to use boto3 as globus is slow). +| Upload: normal way of registering datasets, filerecords created and bulk sync, bulk transfer + jobs on Alyx transfer the data. + +>>> task = VideoSyncQcBpod(session_path, one=one, sync='bpod') +>>> task.force = True +>>> task.run() +>>> task.register_datasets(one=one, labs=get_lab(session_path, alyx=ONE().alyx)) +>>> task.cleanUp() # Delete the files that have been downloaded + +4. Running a task on the local server that doesn't belongs to that subject +(e.g SWC054 on angelakilab). +| Download: via boto3, the AWS file records must exist and be set to exists = True. +| Upload: via globus, automatically uploads the datasets directly to FlatIron via globus. + Creates FlatIron filerecords and sets these to True once the globus task has completed. + +>>> task = VideoSyncQcBpod(session_path, one=one, location='AWS', sync='bpod') +>>> task.run() +>>> task.register_datasets() +>>> task.cleanUp() # Delete the files that have been downloaded + +5. Running a task on SDSC. +| Download: via creating symlink to relevant datasets on SDSC. +| Upload: via copying files to relevant location on SDSC. + +>>> task = VideoSyncQcBpod(session_path, one=one, location='SDSC', sync='bpod') +>>> task.run() +>>> response = task.register_datasets() +>>> # Here we just make sure filerecords are all correct +>>> for resp in response: +... fi = next((fr for fr in resp['file_records'] if 'flatiron' in fr['data_repository']), None) +... if fi is not None: +... if not fi['exists']: +... one.alyx.rest('files', 'partial_update', id=fi['id'], data={'exists': True}) +... +... aws = next((fr for fr in resp['file_records'] if 'aws' in fr['data_repository']), None) +... if aws is not None: +... one.alyx.rest('files', 'partial_update', id=aws['id'], data={'exists': False}) +... +... sr = next((fr for fr in resp['file_records'] if 'SR' in fr['data_repository']), None) +... if sr is not None: +... one.alyx.rest('files', 'partial_update', id=sr['id'], data={'exists': False}) +... # Finally remove symlinks once the task has completed +... task.cleanUp() + +""" from pathlib import Path import abc import logging From 148c4c2616244db255802e8d52d92f98c5bb5cad Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 23 Oct 2023 16:44:29 +0300 Subject: [PATCH 34/68] Ensure paths are str in GlobusPatcher; ignore tmin, tmax if no trials start pulses detected --- ibllib/io/extractors/ephys_fpga.py | 9 ++++++--- ibllib/oneibl/patcher.py | 5 +++-- ibllib/tests/qc/test_critical_reasons.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index eff8c9ab6..ee647e193 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -832,9 +832,12 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', # in ephysChoiceWorld that tend to ruin the final trial extraction). t_trial_start, *_ = _assign_events_bpod(bpod['times'], bpod['polarities']) bpod_start = self.bpod_trials['intervals_bpod'][:, 0] - fcn, *_ = neurodsp.utils.sync_timestamps(bpod_start, t_trial_start) - tmin = fcn(trials_table['intervals'][0, 0]) - 1 - tmax = fcn(trials_table['intervals'][-1, 1]) + 1 + if len(t_trial_start) > len(bpod_start) / 2: + fcn, *_ = neurodsp.utils.sync_timestamps(bpod_start, t_trial_start) + tmin = fcn(trials_table['intervals'][0, 0]) - 1 + tmax = fcn(trials_table['intervals'][-1, 1]) + 1 + else: # This type of alignment fails for some sessions, e.g. mesoscope + tmin = tmax = None # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync( diff --git a/ibllib/oneibl/patcher.py b/ibllib/oneibl/patcher.py index 5cbb1689d..c5aa12975 100644 --- a/ibllib/oneibl/patcher.py +++ b/ibllib/oneibl/patcher.py @@ -250,6 +250,7 @@ def _scp(self, local_path, remote_path, dry=True): remote_path.relative_to(PurePosixPath(FLATIRON_MOUNT)) ) _logger.info(f"Globus copy {local_path} to {remote_path}") + local_path = globus.as_globus_path(local_path) if not dry: if isinstance(self.globus_transfer, globus_sdk.TransferData): self.globus_transfer.add_item(local_path, remote_path.as_posix()) @@ -263,9 +264,9 @@ def _rm(self, flatiron_path, dry=True): _logger.info(f'Globus del {flatiron_path}') if not dry: if isinstance(self.globus_delete, globus_sdk.DeleteData): - self.globus_delete.add_item(flatiron_path) + self.globus_delete.add_item(flatiron_path.as_posix()) else: - self.globus_delete.path.append(flatiron_path) + self.globus_delete.path.append(flatiron_path.as_posix()) return 0, '' def patch_datasets(self, file_list, **kwargs): diff --git a/ibllib/tests/qc/test_critical_reasons.py b/ibllib/tests/qc/test_critical_reasons.py index 1c6c44f00..0d1ae619d 100644 --- a/ibllib/tests/qc/test_critical_reasons.py +++ b/ibllib/tests/qc/test_critical_reasons.py @@ -18,7 +18,7 @@ def mock_input(prompt): if "Select from this list the reason(s)" in prompt: - return "1,3" + return "1," + prompt[prompt.index(') Other') - 1] # always choose last option, 'Other' elif "Explain why you selected" in prompt: return "estoy un poco preocupada" elif "You are about to delete" in prompt: From 47b5eff1ce12eaa5e6db5df0b44a223f4c65f9c1 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 24 Oct 2023 13:03:50 +0300 Subject: [PATCH 35/68] remove deprecated psychofit module; fix tests --- brainbox/behavior/pyschofit.py | 306 ----------------------- brainbox/behavior/training.py | 1 - ibllib/tests/qc/test_critical_reasons.py | 22 +- 3 files changed, 11 insertions(+), 318 deletions(-) delete mode 100644 brainbox/behavior/pyschofit.py diff --git a/brainbox/behavior/pyschofit.py b/brainbox/behavior/pyschofit.py deleted file mode 100644 index 4162eb90d..000000000 --- a/brainbox/behavior/pyschofit.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -(DEPRECATED) The psychofit toolbox contains tools to fit two-alternative psychometric -data. The fitting is done using maximal likelihood estimation: one -assumes that the responses of the subject are given by a binomial -distribution whose mean is given by the psychometric function. -The data can be expressed in fraction correct (from .5 to 1) or in -fraction of one specific choice (from 0 to 1). To fit them you can use -these functions: - - weibull50: Weibull function from 0.5 to 1, with lapse rate - - weibull: Weibull function from 0 to 1, with lapse rate - - erf_psycho: erf function from 0 to 1, with lapse rate - - erf_psycho_2gammas: erf function from 0 to 1, with two lapse rates -Functions in the toolbox are: - - mle_fit_psycho: Maximumum likelihood fit of psychometric function - - neg_likelihood: Negative likelihood of a psychometric function -For more info, see: - - Examples: Examples of use of psychofit toolbox -Matteo Carandini, 2000-2015 - -NB: USE THE PSYCHOFIT PIP PACKAGE INSTEAD. -""" - -import functools -import warnings -import traceback -import logging - -import numpy as np -import scipy.optimize -from scipy.special import erf - - -for line in traceback.format_stack(): - print(line.strip()) - -msg = 'brainbox.behavior.pyschofit has been deprecated. Install psychofit via pip. See stack above' -warnings.warn(msg, DeprecationWarning) -logging.getLogger(__name__).warning(msg) - - -def mle_fit_psycho(data, P_model='weibull', parstart=None, parmin=None, parmax=None, nfits=5): - """ - Maximumum likelihood fit of psychometric function. - Args: - data: 3 x n matrix where first row corresponds to stim levels, - the second to number of trials for each stim level (int), - the third to proportion correct / proportion rightward (float between 0 and 1) - P_model: The psychometric function. Possibilities include 'weibull' - (DEFAULT), 'weibull50', 'erf_psycho' and 'erf_psycho_2gammas' - parstart: Non-zero starting parameters, used to try to avoid local minima. - The parameters are [threshold, slope, gamma], or if using the - 'erf_psycho_2gammas' model append a second gamma value. - Recommended to use a value > 1. If None, some reasonable defaults are used. - parmin: Minimum parameter values. If None, some reasonable defaults are used - parmax: Maximum parameter values. If None, some reasonable defaults are used - nfits: The number of fits - Returns: - pars: The parameters from the best of the fits - L: The likelihood of the best fit - Raises: - TypeError: data must be a list or numpy array - ValueError: data must be m by 3 matrix - Examples: - Below we fit a Weibull function to some data: - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> cc = np.array([-8., -6., -4., -2., 0., 2., 4., 6., 8.]) # contrasts - >>> nn = np.full((9,), 10) # number of trials at each contrast - >>> pp = np.array([5., 8., 20., 41., 54., 59., 79., 92., 96])/100 # proportion "rightward" - >>> pars, L = mle_fit_psycho(np.vstack((cc, nn, pp)), 'erf_psycho') - >>> plt.plot(cc, pp, 'bo', mfc='b') - >>> plt.plot(np.arange(-8, 8, 0.1), erf_psycho(pars, np.arange(-8, 8, 0.1)), '-b') - Information: - 1999-11 FH wrote it - 2000-01 MC cleaned it up - 2000-04 MC took care of the 50% case - 2009-12 MC replaced fmins with fminsearch - 2010-02 MC, AZ added nfits - 2013-02 MC+MD fixed bug with dealing with NaNs - 2018-08 MW ported to Python - """ - # Input validation - if isinstance(data, (list, tuple)): - data = np.array(data) - elif not isinstance(data, np.ndarray): - raise TypeError('data must be a list or numpy array') - - if data.shape[0] != 3: - raise ValueError('data must be m by 3 matrix') - - rep = lambda x: (x, x) if P_model.endswith('2gammas') else (x,) # noqa - if parstart is None: - parstart = np.array([np.mean(data[0, :]), 3., *rep(.05)]) - if parmin is None: - parmin = np.array([np.min(data[0, :]), 0., *rep(0.)]) - if parmax is None: - parmax = np.array([np.max(data[0, :]), 10., *rep(.4)]) - - # find the good values in pp (conditions that were effectively run) - ii = np.isfinite(data[2, :]) - - likelihoods = np.zeros(nfits,) - pars = np.empty((nfits, parstart.size)) - - f = functools.partial(neg_likelihood, data=data[:, ii], - P_model=P_model, parmin=parmin, parmax=parmax) - for ifit in range(nfits): - pars[ifit, :] = scipy.optimize.fmin(f, parstart, disp=False) - parstart = parmin + np.random.rand(parmin.size) * (parmax - parmin) - likelihoods[ifit] = -neg_likelihood(pars[ifit, :], data[:, ii], P_model, parmin, parmax) - - # the values to be output - L = likelihoods.max() - iBestFit = likelihoods.argmax() - return pars[iBestFit, :], L - - -def neg_likelihood(pars, data, P_model='weibull', parmin=None, parmax=None): - """ - Negative likelihood of a psychometric function. - Args: - pars: Model parameters [threshold, slope, gamma], or if - using the 'erf_psycho_2gammas' model append a second gamma value. - data: 3 x n matrix where first row corresponds to stim levels, - the second to number of trials for each stim level (int), - the third to proportion correct / proportion rightward (float between 0 and 1) - P_model: The psychometric function. Possibilities include 'weibull' - (DEFAULT), 'weibull50', 'erf_psycho' and 'erf_psycho_2gammas' - parmin: Minimum bound for parameters. If None, some reasonable defaults are used - parmax: Maximum bound for parameters. If None, some reasonable defaults are used - Returns: - ll: The likelihood of the parameters. The equation is: - - sum(nn.*(pp.*log10(P_model)+(1-pp).*log10(1-P_model))) - See the the appendix of Watson, A.B. (1979). Probability - summation over time. Vision Res 19, 515-522. - Raises: - ValueError: invalid model, options are "weibull", - "weibull50", "erf_psycho" and "erf_psycho_2gammas" - TypeError: data must be a list or numpy array - ValueError data must be m by 3 matrix - Information: - 1999-11 FH wrote it - 2000-01 MC cleaned it up - 2000-07 MC made it indep of Weibull and added parmin and parmax - 2018-08 MW ported to Python - """ - # Validate input - if isinstance(data, (list, tuple)): - data = np.array(data) - elif not isinstance(data, np.ndarray): - raise TypeError('data must be a list or numpy array') - - if parmin is None: - parmin = np.array([.005, 0., 0.]) - if parmax is None: - parmax = np.array([.5, 10., .25]) - - if data.shape[0] == 3: - xx = data[0, :] - nn = data[1, :] - pp = data[2, :] - else: - raise ValueError('data must be m by 3 matrix') - - # here is where you effectively put the constraints. - if (any(pars < parmin)) or (any(pars > parmax)): - ll = 10000000 - return ll - - dispatcher = { - 'weibull': weibull, - 'weibull50': weibull50, - 'erf_psycho': erf_psycho, - 'erf_psycho_2gammas': erf_psycho_2gammas - } - try: - probs = dispatcher[P_model](pars, xx) - except KeyError: - raise ValueError('invalid model, options are "weibull", ' + - '"weibull50", "erf_psycho" and "erf_psycho_2gammas"') - - assert (max(probs) <= 1) or (min(probs) >= 0), 'At least one of the probabilities is not ' \ - 'between 0 and 1' - - probs[probs == 0] = np.finfo(float).eps - probs[probs == 1] = 1 - np.finfo(float).eps - - ll = - sum(nn * (pp * np.log(probs) + (1 - pp) * np.log(1 - probs))) - return ll - - -def weibull(pars, xx): - """ - Weibull function from 0 to 1, with lapse rate. - Args: - pars: Model parameters [alpha, beta, gamma]. - xx: vector of stim levels. - Returns: - A vector of length xx - Raises: - ValueError: pars must be of length 3 - TypeError: pars must be list-like or numpy array - Information: - 1999-11 FH wrote it - 2000-01 MC cleaned it up - 2018-08 MW ported to Python - """ - # Validate input - if not isinstance(pars, (list, tuple, np.ndarray)): - raise TypeError('pars must be list-like or numpy array') - - if len(pars) != 3: - raise ValueError('pars must be of length 3') - - alpha, beta, gamma = pars - return (1 - gamma) - (1 - 2 * gamma) * np.exp(-((xx / alpha) ** beta)) - - -def weibull50(pars, xx): - """ - Weibull function from 0.5 to 1, with lapse rate. - Args: - pars: Model parameters [alpha, beta, gamma]. - xx: vector of stim levels. - Returns: - A vector of length xx - Raises: - ValueError: pars must be of length 3 - TypeError: pars must be list-like or numpy array - Information: - 2000-04 MC wrote it - 2018-08 MW ported to Python - """ - # Validate input - if not isinstance(pars, (list, tuple, np.ndarray)): - raise TypeError('pars must be list-like or numpy array') - - if len(pars) != 3: - raise ValueError('pars must be of length 3') - - alpha, beta, gamma = pars - return (1 - gamma) - (.5 - gamma) * np.exp(-((xx / alpha) ** beta)) - - -def erf_psycho(pars, xx): - """ - erf function from 0 to 1, with lapse rate. - Args: - pars: Model parameters [bias, slope, lapse]. - xx: vector of stim levels. - Returns: - ff: A vector of length xx - Examples: - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> xx = np.arange(-50, 50) - >>> ff = erf_psycho(np.array([-10., 10., 0.1]), xx) - >>> plt.plot(xx, ff) - Raises: - ValueError: pars must be of length 3 - TypeError: pars must be a list or numpy array - Information: - 2000 MC wrote it - 2018-08 MW ported to Python - """ - # Validate input - if not isinstance(pars, (list, tuple, np.ndarray)): - raise TypeError('pars must be list-like or numpy array') - - if len(pars) != 3: - raise ValueError('pars must be of length 4') - - (bias, slope, gamma) = pars - return gamma + (1 - 2 * gamma) * (erf((xx - bias) / slope) + 1) / 2 - - -def erf_psycho_2gammas(pars, xx): - """ - erf function from 0 to 1, with two lapse rates. - Args: - pars: Model parameters [bias, slope, gamma]. - xx: vector of stim levels (%) - Returns: - ff: A vector of length xx - Examples: - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> xx = np.arange(-50, 50) - >>> ff = erf_psycho_2gammas(np.array([-10., 10., 0.2, 0.]), xx) - >>> plt.plot(xx, ff) - Raises: - ValueError: pars must be of length 4 - TypeError: pars must be list-like or numpy array - Information: - 2000 MC wrote it - 2018-08 MW ported to Python - """ - # Validate input - if not isinstance(pars, (list, tuple, np.ndarray)): - raise TypeError('pars must be a list-like or numpy array') - - if len(pars) != 4: - raise ValueError('pars must be of length 4') - - (bias, slope, gamma1, gamma2) = pars - return gamma1 + (1 - gamma1 - gamma2) * (erf((xx - bias) / slope) + 1) / 2 diff --git a/brainbox/behavior/training.py b/brainbox/behavior/training.py index 2e6c9f9fd..2a655252c 100644 --- a/brainbox/behavior/training.py +++ b/brainbox/behavior/training.py @@ -52,7 +52,6 @@ from one.api import ONE from one.alf.io import AlfBunch from one.alf.exceptions import ALFObjectNotFound - import psychofit as psy _logger = logging.getLogger('ibllib') diff --git a/ibllib/tests/qc/test_critical_reasons.py b/ibllib/tests/qc/test_critical_reasons.py index 0d1ae619d..4094ff468 100644 --- a/ibllib/tests/qc/test_critical_reasons.py +++ b/ibllib/tests/qc/test_critical_reasons.py @@ -58,9 +58,9 @@ def test_userinput_sess(self): print(critical_dict) expected_dict = { 'title': '=== EXPERIMENTER REASON(S) FOR MARKING THE SESSION AS CRITICAL ===', - 'reasons_selected': ['synching impossible', 'essential dataset missing'], - 'reason_for_other': []} - assert expected_dict == critical_dict + 'reasons_selected': ['synching impossible', 'Other'], + 'reason_for_other': 'estoy un poco preocupada'} + self.assertDictEqual(expected_dict, critical_dict) def test_userinput_ins(self): eid = self.ins_id # probe id @@ -70,9 +70,9 @@ def test_userinput_ins(self): critical_dict = json.loads(note[0]['text']) expected_dict = { 'title': '=== EXPERIMENTER REASON(S) FOR MARKING THE INSERTION AS CRITICAL ===', - 'reasons_selected': ['Track not visible on imaging data', 'Drift'], - 'reason_for_other': []} - assert expected_dict == critical_dict + 'reasons_selected': ['Track not visible on imaging data', 'Other'], + 'reason_for_other': 'estoy un poco preocupada'} + self.assertDictEqual(expected_dict, critical_dict) def test_note_already_existing(self): eid = self.sess_id # sess id @@ -85,8 +85,8 @@ def test_note_already_existing(self): usrpmt.main(eid, one=one) note = one.alyx.rest('notes', 'list', django=f'object_id,{eid}', no_cache=True) - assert len(note) == 1 - assert original_note_id != note[0]['id'] + self.assertEqual(len(note), 1) + self.assertNotEquals(original_note_id, note[0]['id']) def test_guiinput_ins(self): eid = self.ins_id # probe id @@ -103,12 +103,12 @@ def test_guiinput_ins(self): note = one.alyx.rest('notes', 'list', django=f'text__icontains,{str_notes_static},object_id,{eid}', no_cache=True) - assert len(note) == 1 + self.assertEqual(len(note), 1) critical_dict = json.loads(note[0]['text']) expected_dict = { 'title': '=== EXPERIMENTER REASON(S) FOR MARKING THE INSERTION AS CRITICAL ===', 'reasons_selected': ['Drift'], 'reason_for_other': []} - assert expected_dict == critical_dict + self.assertDictEqual(expected_dict, critical_dict) def test_note_probe_ins(self): # Note: this test is redundant with the above, but it tests specifically whether @@ -135,7 +135,7 @@ def test_note_probe_ins(self): notes = one.alyx.rest('notes', 'list', django=f'text__icontains,{note_text},object_id,{eid}', no_cache=True) - assert len(notes) == 1 + self.assertEqual(len(notes), 1) def tearDown(self) -> None: try: From eb41a3157c79c31120926ea65be8d2d41e926df0 Mon Sep 17 00:00:00 2001 From: Gaelle Date: Tue, 24 Oct 2023 12:38:59 +0200 Subject: [PATCH 36/68] doc raw data probe geometry --- .../loading_data/loading_raw_ephys_data.ipynb | 57 +++++++++---------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/examples/loading_data/loading_raw_ephys_data.ipynb b/examples/loading_data/loading_raw_ephys_data.ipynb index 9394b675e..56716f450 100644 --- a/examples/loading_data/loading_raw_ephys_data.ipynb +++ b/examples/loading_data/loading_raw_ephys_data.ipynb @@ -152,7 +152,7 @@ "\n", "# Use spikeglx reader to read in the whole raw data\n", "sr = spikeglx.Reader(bin_file)\n", - "sr.shape\n" + "print(sr.shape)" ] }, { @@ -328,57 +328,56 @@ }, { "cell_type": "markdown", - "id": "723df072", - "metadata": {}, "source": [ "## Get the probe geometry" - ] + ], + "metadata": { + "collapsed": false + } }, { "cell_type": "markdown", - "id": "35a0db60", - "metadata": {}, "source": [ "### Using the `eid` and `probe` information" - ] + ], + "metadata": { + "collapsed": false + } }, { "cell_type": "code", "execution_count": null, - "id": "c2335beb", - "metadata": {}, "outputs": [], "source": [ "from brainbox.io.one import load_channel_locations\n", - "from one.api import ONE\n", - "import spikeglx\n", - "import numpy as np\n", - "one = ONE()\n", - "\n", - "pid = 'da8dfec1-d265-44e8-84ce-6ae9c109b8bd'\n", - "eid, probe = one.pid2eid(pid)\n", - "channels = load_channel_locations(eid, probe)[probe]\n", - "channel_geometry = np.c_[channels['axial_um'], channels['lateral_um']]" - ] + "channels = load_channel_locations(eid, probe)\n", + "channels[probe][\"localCoordinates\"]" + ], + "metadata": { + "collapsed": false + } }, { "cell_type": "markdown", - "id": "360f668c", - "metadata": {}, "source": [ "### Using the reader and the `.cbin` file" - ] + ], + "metadata": { + "collapsed": false + } }, { "cell_type": "code", "execution_count": null, - "id": "763d6ac8", - "metadata": {}, "outputs": [], "source": [ - "sr = spikeglx.Reader(bin_file)\n", + "# You would have loaded the bin file as per the loading example above\n", + "# sr = spikeglx.Reader(bin_file)\n", "sr.geometry" - ] + ], + "metadata": { + "collapsed": false + } }, { "cell_type": "markdown", @@ -397,9 +396,9 @@ "metadata": { "celltoolbar": "Edit Metadata", "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python [conda env:iblenv] *", "language": "python", - "name": "python3" + "name": "conda-env-iblenv-py" }, "language_info": { "codemirror_mode": { @@ -411,7 +410,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.9.7" } }, "nbformat": 4, From f76e29d0213d8b996faf4794f9a7358cd322199d Mon Sep 17 00:00:00 2001 From: Gaelle Date: Tue, 24 Oct 2023 12:54:36 +0200 Subject: [PATCH 37/68] local coordinates not existing as a key --- examples/loading_data/loading_raw_ephys_data.ipynb | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/loading_data/loading_raw_ephys_data.ipynb b/examples/loading_data/loading_raw_ephys_data.ipynb index 56716f450..fbe95d3fb 100644 --- a/examples/loading_data/loading_raw_ephys_data.ipynb +++ b/examples/loading_data/loading_raw_ephys_data.ipynb @@ -351,7 +351,10 @@ "source": [ "from brainbox.io.one import load_channel_locations\n", "channels = load_channel_locations(eid, probe)\n", - "channels[probe][\"localCoordinates\"]" + "print(channels[probe].keys())\n", + "# Use the axial and lateral coordinates ; Print example first 4 channels\n", + "print(channels[probe][\"axial_um\"][0:4])\n", + "print(channels[probe][\"lateral_um\"][0:4])" ], "metadata": { "collapsed": false From 276bfea5b87a07cda3fc4da2b76788c95e1f623c Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 25 Oct 2023 16:25:23 +0300 Subject: [PATCH 38/68] Add device_sound key to v7 settings; QC bugfix and test --- ibllib/atlas/__init__.py | 5 +- ibllib/io/extractors/bpod_trials.py | 3 +- ibllib/io/raw_data_loaders.py | 21 ++++++-- ibllib/pipes/__init__.py | 14 ++--- ibllib/plots/misc.py | 4 +- ibllib/qc/base.py | 2 +- ibllib/qc/task_metrics.py | 83 +++++++++++++++++++++-------- ibllib/tests/qc/test_base_qc.py | 16 +++++- 8 files changed, 109 insertions(+), 39 deletions(-) diff --git a/ibllib/atlas/__init__.py b/ibllib/atlas/__init__.py index 22c7720bd..b01f8f5e7 100644 --- a/ibllib/atlas/__init__.py +++ b/ibllib/atlas/__init__.py @@ -1,4 +1,7 @@ -"""A package for working with brain atlases. +"""(DEPRECATED) A package for working with brain atlases. + +For the correct atlas documentation, see +https://docs.internationalbrainlab.org/_autosummary/iblatlas.html For examples and tutorials on using the IBL atlas package, see https://docs.internationalbrainlab.org/atlas_examples.html diff --git a/ibllib/io/extractors/bpod_trials.py b/ibllib/io/extractors/bpod_trials.py index 950797b88..7f1db5cb1 100644 --- a/ibllib/io/extractors/bpod_trials.py +++ b/ibllib/io/extractors/bpod_trials.py @@ -1,4 +1,5 @@ -"""Trials data extraction from raw Bpod output +"""Trials data extraction from raw Bpod output. + This module will extract the Bpod trials and wheel data based on the task protocol, i.e. habituation, training or biased. """ diff --git a/ibllib/io/raw_data_loaders.py b/ibllib/io/raw_data_loaders.py index 200b8ca15..e1d78f5f3 100644 --- a/ibllib/io/raw_data_loaders.py +++ b/ibllib/io/raw_data_loaders.py @@ -16,7 +16,7 @@ from typing import Union from dateutil import parser as dateparser -from pkg_resources import parse_version +from packaging import version import numpy as np import pandas as pd @@ -325,16 +325,26 @@ def _read_settings_json_compatibility_enforced(settings): md['IS_MOCK'] = False if 'IBLRIG_VERSION_TAG' not in md.keys(): md['IBLRIG_VERSION_TAG'] = md.get('IBLRIG_VERSION', '') + if 'device_sound' not in md: + # sound device must be defined in version 8 and later # FIXME this assertion will cause tests to break + assert version.parse(md.get('IBLRIG_VERSION_TAG', '0')) < version.parse('8.0.0') + # in v7 we must infer the device from the sampling frequency if SD is None + if 'sounddevice' in md.get('SD', ''): + device = 'xonar' + else: + freq_map = {192000: 'xonar', 96000: 'harp', 44100: 'sysdefault'} + device = freq_map.get(md.get('SOUND_SAMPLE_FREQ'), 'unknown') + md['device_sound'] = {'OUTPUT': device} # 2018-12-05 Version 3.2.3 fixes (permanent fixes in IBL_RIG from 3.2.4 on) if md['IBLRIG_VERSION_TAG'] == '': pass - elif parse_version(md.get('IBLRIG_VERSION_TAG')) >= parse_version('8.0.0'): + elif version.parse(md.get('IBLRIG_VERSION_TAG', '0')) >= version.parse('8.0.0'): md['SESSION_NUMBER'] = str(md['SESSION_NUMBER']).zfill(3) md['PYBPOD_BOARD'] = md['RIG_NAME'] md['PYBPOD_CREATOR'] = (md['ALYX_USER'], '') md['SESSION_DATE'] = md['SESSION_START_TIME'][:10] md['SESSION_DATETIME'] = md['SESSION_START_TIME'] - elif parse_version(md.get('IBLRIG_VERSION_TAG')) <= parse_version('3.2.3'): + elif version.parse(md.get('IBLRIG_VERSION_TAG', '0')) <= version.parse('3.2.3'): if 'LAST_TRIAL_DATA' in md.keys(): md.pop('LAST_TRIAL_DATA') if 'weighings' in md['PYBPOD_SUBJECT_EXTRA'].keys(): @@ -423,7 +433,7 @@ def load_encoder_events(session_path, task_collection='raw_behavior_data', setti settings = {'IBLRIG_VERSION_TAG': '0.0.0'} if not path: return None - if parse_version(settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'): + if version.parse(settings['IBLRIG_VERSION_TAG']) >= version.parse('5.0.0'): return _load_encoder_events_file_ge5(path) else: return _load_encoder_events_file_lt5(path) @@ -528,7 +538,7 @@ def load_encoder_positions(session_path, task_collection='raw_behavior_data', se if not path: _logger.warning("No data loaded: could not find raw encoderPositions file") return None - if parse_version(settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'): + if version.parse(settings['IBLRIG_VERSION_TAG']) >= version.parse('5.0.0'): return _load_encoder_positions_file_ge5(path) else: return _load_encoder_positions_file_lt5(path) @@ -963,3 +973,4 @@ def patch_settings(session_path, collection='raw_behavior_data', with open(file_path, 'w') as fp: json.dump(settings, fp, indent=' ') return settings + diff --git a/ibllib/pipes/__init__.py b/ibllib/pipes/__init__.py index 95e8c6ce9..68c1d8445 100644 --- a/ibllib/pipes/__init__.py +++ b/ibllib/pipes/__init__.py @@ -1,16 +1,16 @@ """IBL preprocessing pipeline. This module concerns the data extraction and preprocessing for IBL data. The lab servers routinely -call `local_server.job_creator` to search for new sessions to extract. The job creator registers -the new session to Alyx (i.e. creates a new session record on the database), if required, then -deduces a set of tasks (a.k.a. the pipeline [*]_) from the 'experiment.description' file at the -root of the session (see `dynamic_pipeline.make_pipeline`). If no file exists one is created, +call :func:`local_server.job_creator` to search for new sessions to extract. The job creator +registers the new session to Alyx (i.e. creates a new session record on the database), if required, +then deduces a set of tasks (a.k.a. the pipeline[*]_) from the 'experiment.description' file at the +root of the session (see :func:`dynamic_pipeline.make_pipeline`). If no file exists one is created, inferring the acquisition hardware from the task protocol. The new session's pipeline tasks are then registered for another process (or server) to query. -Another process calls `local_server.task_queue` to get a list of queued tasks from Alyx, then -`local_server.tasks_runner` to loop through tasks. Each task is run by called -`tasks.run_alyx_task` with a dictionary of task information, including the Task class and its +Another process calls :func:`local_server.task_queue` to get a list of queued tasks from Alyx, then +:func:`local_server.tasks_runner` to loop through tasks. Each task is run by calling +:func:`tasks.run_alyx_task` with a dictionary of task information, including the Task class and its parameters. .. [*] A pipeline is a collection of tasks that depend on one another. A pipeline consists of diff --git a/ibllib/plots/misc.py b/ibllib/plots/misc.py index 36cd56afb..2a561ae8d 100644 --- a/ibllib/plots/misc.py +++ b/ibllib/plots/misc.py @@ -187,9 +187,9 @@ def squares(tscale, polarity, ax=None, yrange=[-1, 1], **kwargs): def vertical_lines(x, ymin=0, ymax=1, ax=None, **kwargs): """ - From a x vector, draw separate vertical lines at each x location ranging from ymin to ymax + From an x vector, draw separate vertical lines at each x location ranging from ymin to ymax - :param x: numpy array vector of x values where to display lnes + :param x: numpy array vector of x values where to display lines :param ymin: lower end of the lines (scalar) :param ymax: higher end of the lines (scalar) :param ax: (optional) matplotlib axis instance diff --git a/ibllib/qc/base.py b/ibllib/qc/base.py index 4669a860e..0fce5d18c 100644 --- a/ibllib/qc/base.py +++ b/ibllib/qc/base.py @@ -222,7 +222,7 @@ def compute_outcome_from_extended_qc(self) -> str: """ details = self.one.alyx.get(f'/{self.endpoint}/{self.eid}', clobber=True) extended_qc = details['json']['extended_qc'] if self.json else details['extended_qc'] - return self.overall_outcome(v for k, v in extended_qc or {} if k[0] != '_') + return self.overall_outcome(v for k, v in extended_qc.items() or {} if k[0] != '_') def sign_off_dict(exp_dec, sign_off_categories=None): diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 42361645d..7738f8bde 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -130,23 +130,34 @@ def __init__(self, session_path_or_eid, **kwargs): self.passed = None def load_data(self, bpod_only=False, download_data=True): - """Extract the data from raw data files + """Extract the data from raw data files. + Extracts all the required task data from the raw data files. - :param bpod_only: if True no data is extracted from the FPGA for ephys sessions - :param download_data: if True, any missing raw data is downloaded via ONE. + Parameters + ---------- + bpod_only : bool + If True no data is extracted from the FPGA for ephys sessions. + download_data : bool + If True, any missing raw data is downloaded via ONE. By default data are not downloaded + if a session path was provided to the constructor. """ self.extractor = TaskQCExtractor( self.session_path, one=self.one, download_data=download_data, bpod_only=bpod_only) def compute(self, **kwargs): - """Compute and store the QC metrics + """Compute and store the QC metrics. + Runs the QC on the session and stores a map of the metrics for each datapoint for each - test, and a map of which datapoints passed for each test - :param bpod_only: if True no data is extracted from the FPGA for ephys sessions - :param download_data: if True, any missing raw data is downloaded via ONE. By default - data are not downloaded if a session path was provided to the constructor. - :return: + test, and a map of which datapoints passed for each test. + + Parameters + ---------- + bpod_only : bool + If True no data is extracted from the FPGA for ephys sessions. + download_data : bool + If True, any missing raw data is downloaded via ONE. By default data are not downloaded + if a session path was provided to the constructor. """ if self.extractor is None: kwargs['download_data'] = kwargs.pop('download_data', self.download_data) @@ -164,11 +175,26 @@ def compute(self, **kwargs): def run(self, update=False, namespace='task', **kwargs): """ - :param update: if True, updates the session QC fields on Alyx - :param bpod_only: if True no data is extracted from the FPGA for ephys sessions - :param download_data: if True, any missing raw data is downloaded via ONE. By default - data are not downloaded if a session path was provided to the constructor. - :return: QC outcome (str), a dict for extended QC + Compute the QC outcomes and return overall task QC outcome. + + Parameters + ---------- + update : bool + If True, updates the session QC fields on Alyx. + namespace : str + The namespace of the QC fields in the Alyx JSON field. + bpod_only : bool + If True no data is extracted from the FPGA for ephys sessions. + download_data : bool + If True, any missing raw data is downloaded via ONE. By default data are not downloaded + if a session path was provided to the constructor. + + Returns + ------- + str + Overall task QC outcome. + dict + A map of QC tests and the proportion of data points that passed them. """ if self.metrics is None: self.compute(**kwargs) @@ -183,9 +209,18 @@ def compute_session_status_from_dict(results): """ Given a dictionary of results, computes the overall session QC for each key and aggregates in a single value - :param results: a dictionary of qc keys containing (usually scalar) values - :return: Overall session QC outcome as a string - :return: A dict of QC tests and their outcomes + + Parameters + ---------- + results : dict + A dictionary of QC keys containing (usually scalar) values. + + Returns + ------- + str + Overall session QC outcome as a string. + dict + A map of QC tests and their outcomes. """ indices = np.zeros(len(results), dtype=int) for i, k in enumerate(results): @@ -203,10 +238,16 @@ def key_map(x): def compute_session_status(self): """ - Computes the overall session QC for each key and aggregates in a single value - :return: Overall session QC outcome as a string - :return: A dict of QC tests and the proportion of data points that passed them - :return: A dict of QC tests and their outcomes + Computes the overall session QC for each key and aggregates in a single value. + + Returns + ------- + str + Overall session QC outcome. + dict + A map of QC tests and the proportion of data points that passed them. + dict + A map of QC tests and their outcomes. """ if self.passed is None: raise AttributeError('passed is None; compute QC first') diff --git a/ibllib/tests/qc/test_base_qc.py b/ibllib/tests/qc/test_base_qc.py index b5e68dda0..e56750c64 100644 --- a/ibllib/tests/qc/test_base_qc.py +++ b/ibllib/tests/qc/test_base_qc.py @@ -1,4 +1,5 @@ import unittest +from unittest import mock import numpy as np @@ -100,6 +101,7 @@ def test_extended_qc(self) -> None: self.assertEqual(updated, {**current, **data}, 'failed to update the extended qc') def test_outcome_setter(self): + """Test for QC.outcome property setter.""" qc = self.qc qc.outcome = 'Fail' self.assertEqual(qc.outcome, 'FAIL') @@ -116,11 +118,23 @@ def test_outcome_setter(self): self.assertEqual(qc.outcome, 'PASS') def test_code_to_outcome(self): + """Test for QC.code_to_outcome method.""" self.assertEqual(QC.code_to_outcome(3), 'FAIL') def test_overall_outcome(self): + """Test for QC.overall_outcome method.""" self.assertEqual(QC.overall_outcome(['PASS', 'NOT_SET', None, 'FAIL']), 'FAIL') + def test_compute_outcome_from_extended_qc(self): + """Test for QC.compute_outcome_from_extended_qc method.""" + detail = {'extended_qc': {'foo': 'FAIL', 'bar': 'WARNING', '_baz_': 'CRITICAL'}, + 'json': {'extended_qc': {'foo': 'PASS', 'bar': 'WARNING', '_baz_': 'CRITICAL'}}} + with mock.patch.object(self.qc.one.alyx, 'get', return_value=detail): + self.qc.json = False + self.assertEqual(self.qc.compute_outcome_from_extended_qc(), 'FAIL') + self.qc.json = True + self.assertEqual(self.qc.compute_outcome_from_extended_qc(), 'WARNING') -if __name__ == "__main__": + +if __name__ == '__main__': unittest.main(exit=False, verbosity=2) From c9850ecff54d7f308cacae779340ab317191cf83 Mon Sep 17 00:00:00 2001 From: owinter Date: Thu, 19 Oct 2023 08:24:26 +0100 Subject: [PATCH 39/68] sdsc load data possible by monkey patching alfio in brainbox.io.one --- .flake8 | 2 +- brainbox/io/one.py | 28 +++++++++++++++++++++------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/.flake8 b/.flake8 index 308114caf..ffc753159 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] max-line-length = 130 -ignore = W504, W503, E266 +ignore = W504, W503, E266, D, BLK exclude = .git, __pycache__, diff --git a/brainbox/io/one.py b/brainbox/io/one.py index 73870391c..b7fb5e535 100644 --- a/brainbox/io/one.py +++ b/brainbox/io/one.py @@ -12,10 +12,10 @@ import matplotlib.pyplot as plt from one.api import ONE, One -import one.alf.io as alfio from one.alf.files import get_alf_path from one.alf.exceptions import ALFObjectNotFound from one.alf import cache +import one.alf.io as alfio from neuropixel import TIP_SIZE_UM, trace_header import spikeglx @@ -830,6 +830,20 @@ def __post_init__(self): self.atlas = AllenAtlas() self.files = {} + def _load_object(self, *args, **kwargs): + """ + This function is a wrapper around alfio.load_object that will remove the UUID in the + filename if the object is on SDSC. + """ + remove_uuids = getattr(self.one, 'uuid_filenames', False) + d = alfio.load_object(*args, **kwargs) + if remove_uuids: + # pops the UUID in the key names + keys = list(d.keys()) + for k in keys: + d[k[:-37]] = d.pop(k) + return d + @staticmethod def _get_attributes(dataset_types): """returns attributes to load for spikes and clusters objects""" @@ -865,7 +879,7 @@ def load_spike_sorting_object(self, obj, *args, **kwargs): :return: """ self.download_spike_sorting_object(obj, *args, **kwargs) - return alfio.load_object(self.files[obj]) + return self._load_object(self.files[obj]) def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None, missing='raise', **kwargs): @@ -922,10 +936,10 @@ def load_channels(self, **kwargs): # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore') if 'electrodeSites' in self.files: - channels = alfio.load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) + channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) else: # otherwise, we try to load the channel object from the spike sorting folder - this may not contain histology self.download_spike_sorting_object(obj='channels', **kwargs) - channels = alfio.load_object(self.files['channels'], wildcards=self.one.wildcards) + channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards) if 'brainLocationIds_ccf_2017' not in channels: _logger.debug(f"loading channels from alyx for {self.files['channels']}") _channels, self.histology = _load_channel_locations_traj( @@ -960,8 +974,8 @@ def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs): self.spike_sorter = spike_sorter self.download_spike_sorting(spike_sorter=spike_sorter, **kwargs) channels = self.load_channels(spike_sorter=spike_sorter, **kwargs) - clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards) - spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards) + clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards) + spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards) return spikes, clusters, channels @@ -1090,7 +1104,7 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_ self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') if 'drift' in self.files: - drift = alfio.load_object(self.files['drift'], wildcards=self.one.wildcards) + drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards) axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5) if save_dir is not None: From 166c4c5efdf7c9f47b735b02077d1c785059b061 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Thu, 2 Nov 2023 14:24:13 +0200 Subject: [PATCH 40/68] bump ONE version --- ibllib/io/extractors/ephys_fpga.py | 5 ++--- requirements.txt | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index ee647e193..b0298011f 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -113,13 +113,12 @@ def _sync_to_alf(raw_ephys_apfile, output_path=None, save=False, parts=''): else: raw_ephys_apfile = Path(raw_ephys_apfile) sr = spikeglx.Reader(raw_ephys_apfile) - opened = sr.is_open - if not opened: # if not (opened := sr.is_open) # py3.8 + if not (opened := sr.is_open): sr.open() # if no output, need a temp folder to swap for big files if not output_path: output_path = raw_ephys_apfile.parent - file_ftcp = Path(output_path).joinpath(f'fronts_times_channel_polarity{str(uuid.uuid4())}.bin') + file_ftcp = Path(output_path).joinpath(f'fronts_times_channel_polarity{uuid.uuid4()}.bin') # loop over chunks of the raw ephys file wg = neurodsp.utils.WindowGenerator(sr.ns, int(SYNC_BATCH_SIZE_SECS * sr.fs), overlap=1) diff --git a/requirements.txt b/requirements.txt index 43d0760de..242ac9a0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ tqdm>=4.32.1 ibl-neuropixel>=0.8.1 iblutil>=1.7.0 labcams # widefield extractor -ONE-api>=2.4 +ONE-api>=2.5 slidingRP>=1.0.0 # steinmetz lab refractory period metrics wfield==0.3.7 # widefield extractor frozen for now (2023/07/15) until Joao fixes latest version psychofit From e7ed533947712685fd6f36b6de95390f510eb0a2 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 3 Nov 2023 15:11:53 +0200 Subject: [PATCH 41/68] use packaging instead of pkg_resources; use IBLRIG_VERSION instead of IBLRIG_VERSION_TAG; soundcard task QC thresholds; change to check_stimulus_move_before_goCue --- ibllib/io/extractors/base.py | 6 +- ibllib/io/extractors/biased_trials.py | 14 ++-- ibllib/io/extractors/bpod_trials.py | 6 +- ibllib/io/extractors/ephys_passive.py | 4 +- ibllib/io/extractors/mesoscope.py | 8 +-- ibllib/io/extractors/training_trials.py | 26 +++---- ibllib/io/raw_data_loaders.py | 45 ++++++++---- ibllib/io/session_params.py | 6 +- ibllib/oneibl/registration.py | 8 +-- ibllib/pipes/base_tasks.py | 8 +-- ibllib/pipes/behavior_tasks.py | 6 +- ibllib/qc/task_metrics.py | 79 ++++++++++++++++------ ibllib/tests/extractors/test_extractors.py | 16 ++--- ibllib/tests/test_base_tasks.py | 2 +- ibllib/tests/test_oneibl.py | 2 +- 15 files changed, 148 insertions(+), 88 deletions(-) diff --git a/ibllib/io/extractors/base.py b/ibllib/io/extractors/base.py index c1b46b22e..1de1cff80 100644 --- a/ibllib/io/extractors/base.py +++ b/ibllib/io/extractors/base.py @@ -145,9 +145,9 @@ def extract(self, bpod_trials=None, settings=None, **kwargs): if not self.settings: self.settings = raw.load_settings(self.session_path, task_collection=self.task_collection) if self.settings is None: - self.settings = {"IBLRIG_VERSION_TAG": "100.0.0"} - elif self.settings.get("IBLRIG_VERSION_TAG", "") == "": - self.settings["IBLRIG_VERSION_TAG"] = "100.0.0" + self.settings = {"IBLRIG_VERSION": "100.0.0"} + elif self.settings.get("IBLRIG_VERSION", "") == "": + self.settings["IBLRIG_VERSION"] = "100.0.0" return super(BaseBpodTrialsExtractor, self).extract(**kwargs) diff --git a/ibllib/io/extractors/biased_trials.py b/ibllib/io/extractors/biased_trials.py index 16d8f8111..07dd64692 100644 --- a/ibllib/io/extractors/biased_trials.py +++ b/ibllib/io/extractors/biased_trials.py @@ -1,6 +1,6 @@ from pathlib import Path, PureWindowsPath -from pkg_resources import parse_version +from packaging import version import numpy as np from one.alf.io import AlfBunch @@ -80,8 +80,8 @@ def get_pregenerated_events(bpod_trials, settings): pLeft = pLeft[: ntrials] phase_path = sessions_folder.joinpath(f"session_{num}_stim_phase.npy") - is_patched_version = parse_version( - settings.get('IBLRIG_VERSION_TAG', 0)) > parse_version('6.4.0') + is_patched_version = version.parse( + settings.get('IBLRIG_VERSION') or '0') > version.parse('6.4.0') if phase_path.exists() and is_patched_version: phase = np.load(phase_path)[:ntrials] @@ -209,13 +209,13 @@ def extract_all(session_path, save=False, bpod_trials=False, settings=False, ext if not settings: settings = raw.load_settings(session_path, task_collection=task_collection) if settings is None: - settings = {'IBLRIG_VERSION_TAG': '100.0.0'} + settings = {'IBLRIG_VERSION': '100.0.0'} - if settings['IBLRIG_VERSION_TAG'] == '': - settings['IBLRIG_VERSION_TAG'] = '100.0.0' + if settings['IBLRIG_VERSION'] == '': + settings['IBLRIG_VERSION'] = '100.0.0' # Version check - if parse_version(settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'): + if version.parse(settings['IBLRIG_VERSION']) >= version.parse('5.0.0'): # We now extract a single trials table base = [BiasedTrials] else: diff --git a/ibllib/io/extractors/bpod_trials.py b/ibllib/io/extractors/bpod_trials.py index 7f1db5cb1..1e72d9da9 100644 --- a/ibllib/io/extractors/bpod_trials.py +++ b/ibllib/io/extractors/bpod_trials.py @@ -8,7 +8,7 @@ from collections import OrderedDict import warnings -from pkg_resources import parse_version +from packaging import version from ibllib.io.extractors import habituation_trials, training_trials, biased_trials, opto_trials from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor from ibllib.io.extractors.habituation_trials import HabituationTrials @@ -89,8 +89,8 @@ def extract_all(session_path, save=True, bpod_trials=None, settings=None, files_wheel = [] wheel = OrderedDict({k: trials.pop(k) for k in tuple(trials.keys()) if 'wheel' in k}) elif extractor_type == 'habituation': - if settings['IBLRIG_VERSION_TAG'] and \ - parse_version(settings['IBLRIG_VERSION_TAG']) <= parse_version('5.0.0'): + if settings['IBLRIG_VERSION'] and \ + version.parse(settings['IBLRIG_VERSION']) <= version.parse('5.0.0'): _logger.warning('No extraction of legacy habituation sessions') return None, None, None trials, files_trials = habituation_trials.extract_all(session_path, bpod_trials=bpod_trials, settings=settings, save=save, diff --git a/ibllib/io/extractors/ephys_passive.py b/ibllib/io/extractors/ephys_passive.py index 2dfcb34e2..f582da076 100644 --- a/ibllib/io/extractors/ephys_passive.py +++ b/ibllib/io/extractors/ephys_passive.py @@ -93,9 +93,11 @@ def _load_task_protocol(session_path: str, task_collection: str = 'raw_passive_d :type session_path: str :return: ibl rig task protocol version :rtype: str + + FIXME This function has a misleading name """ settings = rawio.load_settings(session_path, task_collection=task_collection) - ses_ver = settings["IBLRIG_VERSION_TAG"] + ses_ver = settings["IBLRIG_VERSION"] return ses_ver diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 4def5ed3a..a7b0d1fce 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -7,7 +7,7 @@ from one.alf.files import session_path_parts import matplotlib.pyplot as plt from neurodsp.utils import falls -from pkg_resources import parse_version +from packaging import version from ibllib.plots.misc import squares, vertical_lines from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel, @@ -38,8 +38,8 @@ def patch_imaging_meta(meta: dict) -> dict: The loaded metadata file, updated to the most recent version. """ # 2023-05-17 (unversioned) adds nFrames, channelSaved keys, MM and Deg keys - version = parse_version(meta.get('version') or '0.0.0') - if version <= parse_version('0.0.0'): + ver = version.parse(meta.get('version') or '0.0.0') + if ver <= version.parse('0.0.0'): if 'channelSaved' not in meta: meta['channelSaved'] = next((x['channelIdx'] for x in meta['FOV'] if 'channelIdx' in x), []) fields = ('topLeft', 'topRight', 'bottomLeft', 'bottomRight') @@ -47,7 +47,7 @@ def patch_imaging_meta(meta: dict) -> dict: for unit in ('Deg', 'MM'): if unit not in fov: # topLeftDeg, etc. -> Deg[topLeft] fov[unit] = {f: fov.pop(f + unit, None) for f in fields} - elif version == parse_version('0.1.0'): + elif ver == version.parse('0.1.0'): for fov in meta.get('FOV', []): if 'roiUuid' in fov: fov['roiUUID'] = fov.pop('roiUuid') diff --git a/ibllib/io/extractors/training_trials.py b/ibllib/io/extractors/training_trials.py index 41a69d815..d3ca1447d 100644 --- a/ibllib/io/extractors/training_trials.py +++ b/ibllib/io/extractors/training_trials.py @@ -1,6 +1,6 @@ import logging import numpy as np -from pkg_resources import parse_version +from packaging import version from one.alf.io import AlfBunch import ibllib.io.raw_data_loaders as raw @@ -216,7 +216,7 @@ def get_feedback_times_ge5(session_path, task_collection='raw_behavior_data', da def _extract(self): # Version check - if parse_version(self.settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'): + if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'): merge = self.get_feedback_times_ge5(self.session_path, task_collection=self.task_collection, data=self.bpod_trials) else: merge = self.get_feedback_times_lt5(self.session_path, task_collection=self.task_collection, data=self.bpod_trials) @@ -287,7 +287,7 @@ class GoCueTriggerTimes(BaseBpodTrialsExtractor): var_names = 'goCueTrigger_times' def _extract(self): - if parse_version(self.settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'): + if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'): goCue = np.array([tr['behavior_data']['States timestamps'] ['play_tone'][0][0] for tr in self.bpod_trials]) else: @@ -361,7 +361,7 @@ class IncludedTrials(BaseBpodTrialsExtractor): var_names = 'included' def _extract(self): - if parse_version(self.settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'): + if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'): trials_included = self.get_included_trials_ge5( data=self.bpod_trials, settings=self.settings) else: @@ -370,7 +370,7 @@ def _extract(self): @staticmethod def get_included_trials_lt5(data=False): - trials_included = np.array([True for t in data]) + trials_included = np.ones(len(data), dtype=bool) return trials_included @staticmethod @@ -387,7 +387,7 @@ class ItiInTimes(BaseBpodTrialsExtractor): var_names = 'itiIn_times' def _extract(self): - if parse_version(self.settings["IBLRIG_VERSION_TAG"]) < parse_version("5.0.0"): + if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') < version.parse("5.0.0"): iti_in = np.ones(len(self.bpod_trials)) * np.nan else: iti_in = np.array( @@ -416,7 +416,7 @@ class StimFreezeTriggerTimes(BaseBpodTrialsExtractor): var_names = 'stimFreezeTrigger_times' def _extract(self): - if parse_version(self.settings["IBLRIG_VERSION_TAG"]) < parse_version("6.2.5"): + if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') < version.parse("6.2.5"): return np.ones(len(self.bpod_trials)) * np.nan freeze_reward = np.array( [ @@ -460,9 +460,9 @@ class StimOffTriggerTimes(BaseBpodTrialsExtractor): var_names = 'stimOffTrigger_times' def _extract(self): - if parse_version(self.settings["IBLRIG_VERSION_TAG"]) >= parse_version("6.2.5"): + if version.parse(self.settings["IBLRIG_VERSION"] or '100.0.0') >= version.parse("6.2.5"): stim_off_trigger_state = "hide_stim" - elif parse_version(self.settings["IBLRIG_VERSION_TAG"]) >= parse_version("5.0.0"): + elif version.parse(self.settings["IBLRIG_VERSION"]) >= version.parse("5.0.0"): stim_off_trigger_state = "exit_state" else: stim_off_trigger_state = "trial_start" @@ -518,7 +518,7 @@ def _extract(self): # Version check _logger.warning("Deprecation Warning: this is an old version of stimOn extraction." "From version 5., use StimOnOffFreezeTimes") - if parse_version(self.settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'): + if version.parse(self.settings['IBLRIG_VERSION'] or '100.0.0') >= version.parse('5.0.0'): stimOn_times = self.get_stimOn_times_ge5(self.session_path, data=self.bpod_trials, task_collection=self.task_collection) else: @@ -739,11 +739,11 @@ def extract_all(session_path, save=False, bpod_trials=None, settings=None, task_ bpod_trials = raw.load_data(session_path, task_collection=task_collection) if not settings: settings = raw.load_settings(session_path, task_collection=task_collection) - if settings is None or settings['IBLRIG_VERSION_TAG'] == '': - settings = {'IBLRIG_VERSION_TAG': '100.0.0'} + if settings is None or settings['IBLRIG_VERSION'] == '': + settings = {'IBLRIG_VERSION': '100.0.0'} # Version check - if parse_version(settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'): + if version.parse(settings['IBLRIG_VERSION']) >= version.parse('5.0.0'): # We now extract a single trials table base = [TrainingTrials] else: diff --git a/ibllib/io/raw_data_loaders.py b/ibllib/io/raw_data_loaders.py index e1d78f5f3..c2dae8c71 100644 --- a/ibllib/io/raw_data_loaders.py +++ b/ibllib/io/raw_data_loaders.py @@ -7,6 +7,7 @@ Module contains one loader function per raw datafile """ +import re import json import logging import wave @@ -323,11 +324,30 @@ def _read_settings_json_compatibility_enforced(settings): md = json.load(js) if 'IS_MOCK' not in md: md['IS_MOCK'] = False + # Many v < 8 sessions had both version and version tag keys. v > 8 have a version tag. + # Some sessions have neither key. From v8 onwards we will use IBLRIG_VERSION to test rig + # version, however some places may still use the version tag. if 'IBLRIG_VERSION_TAG' not in md.keys(): md['IBLRIG_VERSION_TAG'] = md.get('IBLRIG_VERSION', '') + if 'IBLRIG_VERSION' not in md.keys(): + md['IBLRIG_VERSION'] = md['IBLRIG_VERSION_TAG'] + elif all([md['IBLRIG_VERSION'], md['IBLRIG_VERSION_TAG']]): + # This may not be an issue; not sure what the intended difference between these keys was + assert md['IBLRIG_VERSION'] == md['IBLRIG_VERSION_TAG'], 'version and version tag mismatch' + # Test version can be parsed. If not, log an error and set the version to nothing + try: + version.parse(md['IBLRIG_VERSION'] or '0') + except version.InvalidVersion as ex: + _logger.error('%s in iblrig settings, this may affect extraction', ex) + # try a more relaxed version parse + laxed_parse = re.search(r'^\d+\.\d+\.\d+', md['IBLRIG_VERSION']) + # Set the tag as the invalid version + md['IBLRIG_VERSION_TAG'] = md['IBLRIG_VERSION'] + # overwrite version with either successfully parsed one or an empty string + md['IBLRIG_VERSION'] = laxed_parse.group() if laxed_parse else '' if 'device_sound' not in md: # sound device must be defined in version 8 and later # FIXME this assertion will cause tests to break - assert version.parse(md.get('IBLRIG_VERSION_TAG', '0')) < version.parse('8.0.0') + assert version.parse(md['IBLRIG_VERSION'] or '0') < version.parse('8.0.0') # in v7 we must infer the device from the sampling frequency if SD is None if 'sounddevice' in md.get('SD', ''): device = 'xonar' @@ -336,15 +356,15 @@ def _read_settings_json_compatibility_enforced(settings): device = freq_map.get(md.get('SOUND_SAMPLE_FREQ'), 'unknown') md['device_sound'] = {'OUTPUT': device} # 2018-12-05 Version 3.2.3 fixes (permanent fixes in IBL_RIG from 3.2.4 on) - if md['IBLRIG_VERSION_TAG'] == '': + if md['IBLRIG_VERSION'] == '': pass - elif version.parse(md.get('IBLRIG_VERSION_TAG', '0')) >= version.parse('8.0.0'): + elif version.parse(md['IBLRIG_VERSION']) >= version.parse('8.0.0'): md['SESSION_NUMBER'] = str(md['SESSION_NUMBER']).zfill(3) md['PYBPOD_BOARD'] = md['RIG_NAME'] md['PYBPOD_CREATOR'] = (md['ALYX_USER'], '') md['SESSION_DATE'] = md['SESSION_START_TIME'][:10] md['SESSION_DATETIME'] = md['SESSION_START_TIME'] - elif version.parse(md.get('IBLRIG_VERSION_TAG', '0')) <= version.parse('3.2.3'): + elif version.parse(md['IBLRIG_VERSION']) <= version.parse('3.2.3'): if 'LAST_TRIAL_DATA' in md.keys(): md.pop('LAST_TRIAL_DATA') if 'weighings' in md['PYBPOD_SUBJECT_EXTRA'].keys(): @@ -424,16 +444,16 @@ def load_encoder_events(session_path, task_collection='raw_behavior_data', setti path = next(path.glob("_iblrig_encoderEvents.raw*.ssv"), None) if not settings: settings = load_settings(session_path, task_collection=task_collection) - if settings is None or not settings.get('IBLRIG_VERSION_TAG'): - settings = {'IBLRIG_VERSION_TAG': '100.0.0'} + if settings is None or not settings.get('IBLRIG_VERSION'): + settings = {'IBLRIG_VERSION': '100.0.0'} # auto-detect old files when version is not labeled with open(path) as fid: line = fid.readline() if line.startswith('Event') and 'StateMachine' in line: - settings = {'IBLRIG_VERSION_TAG': '0.0.0'} + settings = {'IBLRIG_VERSION': '0.0.0'} if not path: return None - if version.parse(settings['IBLRIG_VERSION_TAG']) >= version.parse('5.0.0'): + if version.parse(settings['IBLRIG_VERSION']) >= version.parse('5.0.0'): return _load_encoder_events_file_ge5(path) else: return _load_encoder_events_file_lt5(path) @@ -528,17 +548,17 @@ def load_encoder_positions(session_path, task_collection='raw_behavior_data', se path = next(path.glob("_iblrig_encoderPositions.raw*.ssv"), None) if not settings: settings = load_settings(session_path, task_collection=task_collection) - if settings is None or not settings.get('IBLRIG_VERSION_TAG'): - settings = {'IBLRIG_VERSION_TAG': '100.0.0'} + if settings is None or not settings.get('IBLRIG_VERSION'): + settings = {'IBLRIG_VERSION': '100.0.0'} # auto-detect old files when version is not labeled with open(path) as fid: line = fid.readline() if line.startswith('Position'): - settings = {'IBLRIG_VERSION_TAG': '0.0.0'} + settings = {'IBLRIG_VERSION': '0.0.0'} if not path: _logger.warning("No data loaded: could not find raw encoderPositions file") return None - if version.parse(settings['IBLRIG_VERSION_TAG']) >= version.parse('5.0.0'): + if version.parse(settings['IBLRIG_VERSION']) >= version.parse('5.0.0'): return _load_encoder_positions_file_ge5(path) else: return _load_encoder_positions_file_lt5(path) @@ -973,4 +993,3 @@ def patch_settings(session_path, collection='raw_behavior_data', with open(file_path, 'w') as fp: json.dump(settings, fp, indent=' ') return settings - diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index 5bcaf2873..cd2eccad5 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -30,7 +30,7 @@ from copy import deepcopy from one.converters import ConversionMixin -from pkg_resources import parse_version +from packaging import version import ibllib.pipes.misc as misc @@ -71,9 +71,9 @@ def _patch_file(data: dict) -> dict: The patched description data. """ if data and (v := data.get('version', '0')) != SPEC_VERSION: - if parse_version(v) > parse_version(SPEC_VERSION): + if version.parse(v) > version.parse(SPEC_VERSION): _logger.warning('Description file generated by more recent code') - elif parse_version(v) <= parse_version('0.1.0'): + elif version.parse(v) <= version.parse('0.1.0'): # Change tasks key from dict to list of dicts if 'tasks' in data and isinstance(data['tasks'], dict): data['tasks'] = [{k: v} for k, v in data['tasks'].copy().items()] diff --git a/ibllib/oneibl/registration.py b/ibllib/oneibl/registration.py index 554735e15..470c8aead 100644 --- a/ibllib/oneibl/registration.py +++ b/ibllib/oneibl/registration.py @@ -4,7 +4,7 @@ import logging import itertools -from pkg_resources import parse_version +from packaging import version from one.alf.files import get_session_path, folder_parts, get_alf_path from one.registration import RegistrationClient, get_dataset_type from one.remote.globus import get_local_endpoint_id, get_lab_from_endpoint_id @@ -230,7 +230,7 @@ def register_session(self, ses_path, file_list=True, projects=None, procedures=N n_trials, n_correct_trials = _get_session_performance(settings, task_data) # TODO Add task_protocols to Alyx sessions endpoint - task_protocols = [md['PYBPOD_PROTOCOL'] + md['IBLRIG_VERSION_TAG'] for md in settings] + task_protocols = [md['PYBPOD_PROTOCOL'] + md['IBLRIG_VERSION'] for md in settings] # unless specified label the session projects with subject projects projects = subject['projects'] if projects is None else projects # makes sure projects is a list @@ -298,7 +298,7 @@ def register_session(self, ses_path, file_list=True, projects=None, procedures=N # register all files that match the Alyx patterns and file_list if any(settings): - rename_files_compatibility(ses_path, settings[0]['IBLRIG_VERSION_TAG']) + rename_files_compatibility(ses_path, settings[0]['IBLRIG_VERSION']) F = filter(lambda x: self._register_bool(x.name, file_list), self.find_files(ses_path)) recs = self.register_files(F, created_by=users[0] if users else None, versions=ibllib.__version__) return session, recs @@ -370,7 +370,7 @@ def _alyx_procedure_from_task_type(task_type): def rename_files_compatibility(ses_path, version_tag): if not version_tag: return - if parse_version(version_tag) <= parse_version('3.2.3'): + if version.parse(version_tag) <= version.parse('3.2.3'): task_code = ses_path.glob('**/_ibl_trials.iti_duration.npy') for fn in task_code: fn.replace(fn.parent.joinpath('_ibl_trials.itiDuration.npy')) diff --git a/ibllib/pipes/base_tasks.py b/ibllib/pipes/base_tasks.py index 9005e365d..fc848af85 100644 --- a/ibllib/pipes/base_tasks.py +++ b/ibllib/pipes/base_tasks.py @@ -2,7 +2,7 @@ import logging from pathlib import Path -from pkg_resources import parse_version +from packaging import version from one.webclient import no_cache from iblutil.util import flatten @@ -121,9 +121,9 @@ def _spacer_support(settings): bool True if task spacers are to be expected. """ - v = parse_version - version = v(settings.get('IBLRIG_VERSION_TAG')) - return version not in (v('100.0.0'), v('8.0.0')) and version >= v('7.1.0') + v = version.parse + ver = v(settings.get('IBLRIG_VERSION') or '100.0.0') + return ver not in (v('100.0.0'), v('8.0.0')) and ver >= v('7.1.0') class VideoTask(DynamicTask): diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py index 6f1c8d506..a1fcc900d 100644 --- a/ibllib/pipes/behavior_tasks.py +++ b/ibllib/pipes/behavior_tasks.py @@ -2,7 +2,7 @@ import logging import traceback -from pkg_resources import parse_version +from packaging import version import one.alf.io as alfio from one.alf.files import session_path_parts from one.api import ONE @@ -209,8 +209,8 @@ def _run(self, **kwargs): This class exists to load the sync file and set the protocol_number to None """ settings = load_settings(self.session_path, self.collection) - version = settings.get('IBLRIG_VERSION_TAG', '100.0.0') - if version == '100.0.0' or parse_version(version) <= parse_version('7.1.0'): + ver = settings.get('IBLRIG_VERSION') or '100.0.0' + if ver == '100.0.0' or version.parse(ver) <= version.parse('7.1.0'): _logger.warning('Protocol spacers not supported; setting protocol_number to None') self.protocol_number = None diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 7738f8bde..dd2d3d2b9 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -54,6 +54,7 @@ from collections.abc import Sized import numpy as np +from packaging import version from scipy.stats import chisquare from brainbox.behavior.wheel import cm_to_rad, traces_by_trial @@ -162,6 +163,15 @@ def compute(self, **kwargs): if self.extractor is None: kwargs['download_data'] = kwargs.pop('download_data', self.download_data) self.load_data(**kwargs) + + ver = self.extractor.settings.get('IBLRIG_VERSION', '') or '0.0.0' + if version.parse(ver) >= version.parse('8.0.0'): + self.criteria['_task_iti_delays'] = {'PASS': 0.99, 'WARNING': 0} + self.criteria['_task_passed_trial_checks'] = {'PASS': 0.7, 'WARNING': 0} + else: + self.criteria['_task_iti_delays'] = {'NOT_SET': 0} + self.criteria['_task_passed_trial_checks'] = {'NOT_SET': 0} + self.log.info(f'Session {self.session_path}: Running QC on behavior data...') self.metrics, self.passed = get_bpodqc_metrics_frame( self.extractor.data, @@ -169,7 +179,8 @@ def compute(self, **kwargs): photodiode=self.extractor.frame_ttls, audio=self.extractor.audio_ttls, re_encoding=self.extractor.wheel_encoding or 'X1', - min_qt=self.extractor.settings.get('QUIESCENT_PERIOD') or 0.2 + min_qt=self.extractor.settings.get('QUIESCENT_PERIOD') or 0.2, + audio_output=self.extractor.settings.get('device_sound', {}).get('OUTPUT', 'unknown') ) return @@ -261,9 +272,9 @@ def compute_session_status(self): class HabituationQC(TaskQC): def compute(self, download_data=None): - """Compute and store the QC metrics + """Compute and store the QC metrics. Runs the QC on the session and stores a map of the metrics for each datapoint for each - test, and a map of which datapoints passed for each test + test, and a map of which datapoints passed for each test. :return: """ if self.extractor is None: @@ -275,6 +286,7 @@ def compute(self, download_data=None): # Initialize checks prefix = '_task_' data = self.extractor.data + audio_output = self.extractor.settings.get('device_sound', {}).get('OUTPUT', 'unknown') metrics = {} passed = {} @@ -354,7 +366,7 @@ def compute(self, download_data=None): check_stimOn_delays, check_stimOff_delays] for fcn in checks: check = prefix + fcn.__name__[6:] - metrics[check], passed[check] = fcn(data) + metrics[check], passed[check] = fcn(data, audio_output=audio_output) self.metrics, self.passed = (metrics, passed) @@ -404,7 +416,7 @@ def is_metric(x): # === Delays between events checks === -def check_stimOn_goCue_delays(data, **_): +def check_stimOn_goCue_delays(data, audio_output='harp', **_): """ Checks that the time difference between the onset of the visual stimulus and the onset of the go cue tone is positive and less than 10ms. @@ -413,16 +425,22 @@ def check_stimOn_goCue_delays(data, **_): Units: seconds [s] :param data: dict of trial data with keys ('goCue_times', 'stimOn_times', 'intervals') + :param audio_output: audio output device name. + + Notes + ----- + For non-harp soundcards the permissible delay is 0.053s """ # Calculate the difference between stimOn and goCue times. # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. + threshold = 0.01 if audio_output.lower() == 'harp' else 0.053 metric = np.nan_to_num(data['goCue_times'] - data['stimOn_times'], nan=np.inf) - passed = (metric < 0.01) & (metric > 0) + passed = (metric < threshold) & (metric > 0) assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed -def check_response_feedback_delays(data, **_): +def check_response_feedback_delays(data, audio_output='harp', **_): """ Checks that the time difference between the response and the feedback onset (error sound or valve) is positive and less than 10ms. @@ -431,9 +449,15 @@ def check_response_feedback_delays(data, **_): Units: seconds [s] :param data: dict of trial data with keys ('feedback_times', 'response_times', 'intervals') + :param audio_output: audio output device name. + + Notes + ----- + For non-harp soundcards the permissible delay is 0.053s """ + threshold = 0.01 if audio_output.lower() == 'harp' else 0.053 metric = np.nan_to_num(data['feedback_times'] - data['response_times'], nan=np.inf) - passed = (metric < 0.01) & (metric > 0) + passed = (metric < threshold) & (metric > 0) assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -871,7 +895,7 @@ def check_trial_length(data, **_): # === Trigger-response delay checks === -def check_goCue_delays(data, **_): +def check_goCue_delays(data, audio_output='harp', **_): """ Check that the time difference between the go cue sound being triggered and effectively played is smaller than 1ms. @@ -879,15 +903,21 @@ def check_goCue_delays(data, **_): Criterion: 0 < M <= 0.0015 s Units: seconds [s] - :param data: dict of trial data with keys ('goCue_times', 'goCueTrigger_times', 'intervals') + :param data: dict of trial data with keys ('goCue_times', 'goCueTrigger_times', 'intervals'). + :param audio_output: audio output device name. + + Notes + ----- + For non-harp soundcards the permissible delay is 0.053s """ + threshold = 0.0015 if audio_output.lower() == 'harp' else 0.053 metric = np.nan_to_num(data['goCue_times'] - data['goCueTrigger_times'], nan=np.inf) - passed = (metric <= 0.0015) & (metric > 0) + passed = (metric <= threshold) & (metric > 0) assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed -def check_errorCue_delays(data, **_): +def check_errorCue_delays(data, audio_output='harp', **_): """ Check that the time difference between the error sound being triggered and effectively played is smaller than 1ms. Metric: M = errorCue_times - errorCueTrigger_times @@ -896,9 +926,15 @@ def check_errorCue_delays(data, **_): :param data: dict of trial data with keys ('errorCue_times', 'errorCueTrigger_times', 'intervals', 'correct') + :param audio_output: audio output device name. + + Notes + ----- + For non-harp soundcards the permissible delay is 0.062s """ + threshold = 0.0015 if audio_output.lower() == 'harp' else 0.062 metric = np.nan_to_num(data['errorCue_times'] - data['errorCueTrigger_times'], nan=np.inf) - passed = ((metric <= 0.0015) & (metric > 0)).astype(float) + passed = ((metric <= threshold) & (metric > 0)).astype(float) passed[data['correct']] = metric[data['correct']] = np.nan assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -1025,14 +1061,19 @@ def check_wheel_integrity(data, re_encoding='X1', enc_res=None, **_): # === Pre-stimulus checks === def check_stimulus_move_before_goCue(data, photodiode=None, **_): """ Check that there are no visual stimulus change(s) between the start of the trial and the - go cue sound onset - 20 ms. + go cue sound onset, expect for stim on. - Metric: M = number of visual stimulus change events between trial start and goCue_times - 20ms - Criterion: M == 0 + Metric: M = number of visual stimulus change events between trial start and goCue_times + Criterion: M == 1 Units: -none-, integer :param data: dict of trial data with keys ('goCue_times', 'intervals', 'choice') :param photodiode: the fronts from Bpod's BNC1 input or FPGA frame2ttl channel + + Notes + ----- + - There should be exactly 1 stimulus change before goCue; stimulus onset. Even if the stimulus + contrast is 0, the sync square will still flip at stimulus onset, etc. """ if photodiode is None: _log.warning('No photodiode TTL input in function call, returning None') @@ -1042,11 +1083,9 @@ def check_stimulus_move_before_goCue(data, photodiode=None, **_): s = s[~np.isnan(s)] # Remove NaNs metric = np.array([]) for i, c in zip(data['intervals'][:, 0], data['goCue_times']): - metric = np.append(metric, np.count_nonzero(s[s > i] < (c - 0.02))) + metric = np.append(metric, np.count_nonzero(s[s > i] < c)) - passed = (metric == 0).astype(float) - # Remove no go trials - passed[data['choice'] == 0] = np.nan + passed = (metric == 1).astype(float) assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed diff --git a/ibllib/tests/extractors/test_extractors.py b/ibllib/tests/extractors/test_extractors.py index 56a8de86d..7bd58d4f0 100644 --- a/ibllib/tests/extractors/test_extractors.py +++ b/ibllib/tests/extractors/test_extractors.py @@ -423,7 +423,7 @@ def test_get_included_trials_ge5(self): def test_get_included_trials(self): # TRAINING SESSIONS it = training_trials.IncludedTrials( - self.training_lt5['path']).extract(settings={'IBLRIG_VERSION_TAG': '4.9.9'})[0] + self.training_lt5['path']).extract(settings={'IBLRIG_VERSION': '4.9.9'})[0] self.assertTrue(isinstance(it, np.ndarray)) # -- version >= 5.0.0 it = training_trials.IncludedTrials( @@ -432,7 +432,7 @@ def test_get_included_trials(self): # BIASED SESSIONS it = biased_trials.IncludedTrials( - self.biased_lt5['path']).extract(settings={'IBLRIG_VERSION_TAG': '4.9.9'})[0] + self.biased_lt5['path']).extract(settings={'IBLRIG_VERSION': '4.9.9'})[0] self.assertTrue(isinstance(it, np.ndarray)) # -- version >= 5.0.0 it = biased_trials.IncludedTrials( @@ -445,7 +445,7 @@ def test_extract_all(self): # Expect an error raised because no wheel moves were present in test data with self.assertRaises(ValueError) as ex: training_trials.extract_all( - self.training_lt5['path'], settings={'IBLRIG_VERSION_TAG': '4.9.9'}, save=True) + self.training_lt5['path'], settings={'IBLRIG_VERSION': '4.9.9'}, save=True) self.assertIn('_ibl_wheelMoves.intervals.npy appears to be empty', str(ex.exception)) # -- version >= 5.0.0 out, files = training_trials.extract_all(self.training_ge5['path'], save=True) @@ -459,7 +459,7 @@ def test_extract_all(self): Wheel.var_names = tuple() Wheel().extract.return_value = ({}, []) out, files = biased_trials.extract_all( - self.biased_lt5['path'], settings={'IBLRIG_VERSION_TAG': '4.9.9'}, save=True) + self.biased_lt5['path'], settings={'IBLRIG_VERSION': '4.9.9'}, save=True) self.assertEqual(15, len(out)) self.assertTrue(all(map(Path.exists, files))) # -- version >= 5.0.0 @@ -508,18 +508,18 @@ def test_wheel_folders(self): def test_load_encoder_positions(self): raw.load_encoder_positions(self.training_lt5['path'], - settings={'IBLRIG_VERSION_TAG': '4.9.9'}) + settings={'IBLRIG_VERSION': '4.9.9'}) raw.load_encoder_positions(self.training_ge5['path']) raw.load_encoder_positions(self.biased_lt5['path'], - settings={'IBLRIG_VERSION_TAG': '4.9.9'}) + settings={'IBLRIG_VERSION': '4.9.9'}) raw.load_encoder_positions(self.biased_ge5['path']) def test_load_encoder_events(self): raw.load_encoder_events(self.training_lt5['path'], - settings={'IBLRIG_VERSION_TAG': '4.9.9'}) + settings={'IBLRIG_VERSION': '4.9.9'}) raw.load_encoder_events(self.training_ge5['path']) raw.load_encoder_events(self.biased_lt5['path'], - settings={'IBLRIG_VERSION_TAG': '4.9.9'}) + settings={'IBLRIG_VERSION': '4.9.9'}) raw.load_encoder_events(self.biased_ge5['path']) def test_size_outputs(self): diff --git a/ibllib/tests/test_base_tasks.py b/ibllib/tests/test_base_tasks.py index e91d20450..f5014a162 100644 --- a/ibllib/tests/test_base_tasks.py +++ b/ibllib/tests/test_base_tasks.py @@ -87,7 +87,7 @@ def test_spacer_support(self) -> None: settings = {} spacer_support = partial(base_tasks.BehaviourTask._spacer_support, settings) for version, expected in to_test: - settings['IBLRIG_VERSION_TAG'] = version + settings['IBLRIG_VERSION'] = version with self.subTest(version): self.assertIs(spacer_support(), expected) diff --git a/ibllib/tests/test_oneibl.py b/ibllib/tests/test_oneibl.py index aa9483d6b..9493e4cda 100644 --- a/ibllib/tests/test_oneibl.py +++ b/ibllib/tests/test_oneibl.py @@ -144,7 +144,7 @@ def test_dsets_2_path(self): 'SUBJECT_NAME': SUBJECT, 'PYBPOD_BOARD': '_iblrig_mainenlab_behavior_1', 'PYBPOD_PROTOCOL': '_iblrig_tasks_ephysChoiceWorld', - 'IBLRIG_VERSION_TAG': '5.4.1', + 'IBLRIG_VERSION': '5.4.1', 'SUBJECT_WEIGHT': 22, } From 06d65eaf9356aca3d921ae92f3f6d3bd7da25fce Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 6 Nov 2023 12:45:20 +0200 Subject: [PATCH 42/68] QC notes and test fix --- ibllib/qc/task_metrics.py | 12 ++++++++---- ibllib/tests/qc/test_task_metrics.py | 5 +++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index dd2d3d2b9..7709960e3 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -429,7 +429,8 @@ def check_stimOn_goCue_delays(data, audio_output='harp', **_): Notes ----- - For non-harp soundcards the permissible delay is 0.053s + For non-harp sound card the permissible delay is 0.053s. This was chosen by taking the 99.5th + percentile of delays over 500 training sessions using the Xonar soundcard. """ # Calculate the difference between stimOn and goCue times. # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. @@ -453,7 +454,8 @@ def check_response_feedback_delays(data, audio_output='harp', **_): Notes ----- - For non-harp soundcards the permissible delay is 0.053s + For non-harp sound card the permissible delay is 0.053s. This was chosen by taking the 99.5th + percentile of delays over 500 training sessions using the Xonar soundcard. """ threshold = 0.01 if audio_output.lower() == 'harp' else 0.053 metric = np.nan_to_num(data['feedback_times'] - data['response_times'], nan=np.inf) @@ -908,7 +910,8 @@ def check_goCue_delays(data, audio_output='harp', **_): Notes ----- - For non-harp soundcards the permissible delay is 0.053s + For non-harp sound card the permissible delay is 0.053s. This was chosen by taking the 99.5th + percentile of delays over 500 training sessions using the Xonar soundcard. """ threshold = 0.0015 if audio_output.lower() == 'harp' else 0.053 metric = np.nan_to_num(data['goCue_times'] - data['goCueTrigger_times'], nan=np.inf) @@ -930,7 +933,8 @@ def check_errorCue_delays(data, audio_output='harp', **_): Notes ----- - For non-harp soundcards the permissible delay is 0.062s + For non-harp sound card the permissible delay is 0.062s. This was chosen by taking the 99.5th + percentile of delays over 500 training sessions using the Xonar soundcard. """ threshold = 0.0015 if audio_output.lower() == 'harp' else 0.062 metric = np.nan_to_num(data['errorCue_times'] - data['errorCueTrigger_times'], nan=np.inf) diff --git a/ibllib/tests/qc/test_task_metrics.py b/ibllib/tests/qc/test_task_metrics.py index a2b64f5ea..404876160 100644 --- a/ibllib/tests/qc/test_task_metrics.py +++ b/ibllib/tests/qc/test_task_metrics.py @@ -527,7 +527,8 @@ def setUp(self): eid = '8dd0fcb0-1151-4c97-ae35-2e2421695ad7' one = ONE(**TEST_DB) self.qc = qcmetrics.HabituationQC(eid, one=one) - self.qc.extractor = Bunch({'data': self.load_fake_bpod_data()}) # Dummy extractor obj + # Dummy extractor obj + self.qc.extractor = Bunch({'data': self.load_fake_bpod_data(), 'settings': {}}) @staticmethod def load_fake_bpod_data(n=5): @@ -578,5 +579,5 @@ def test_compute(self): self.assertEqual(outcomes['_task_habituation_time'], 'NOT_SET') -if __name__ == "__main__": +if __name__ == '__main__': unittest.main(exit=False, verbosity=2) From 262bb753830d42ba5a75462a05eb03a787c0c13e Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 6 Nov 2023 12:53:01 +0200 Subject: [PATCH 43/68] Extend globus module deprecation --- ibllib/tests/test_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/tests/test_io.py b/ibllib/tests/test_io.py index 839fa57f8..7bc75951f 100644 --- a/ibllib/tests/test_io.py +++ b/ibllib/tests/test_io.py @@ -363,7 +363,7 @@ def setUp(self): self.addCleanup(self.patcher.stop) def test_as_globus_path(self): - assert datetime.now() < datetime(2023, 10, 30) + assert datetime.now() < datetime(2024, 1, 30), 'remove deprecated module' # A Windows path if sys.platform == 'win32': # "/E/FlatIron/integration" @@ -381,7 +381,7 @@ def test_as_globus_path(self): @unittest.mock.patch('iblutil.io.params.read') def test_login_auto(self, mock_params): - assert datetime.now() < datetime(2023, 10, 30) + assert datetime.now() < datetime(2024, 1, 30), 'remove deprecated module' client_id = 'h3u2ier' # Test ValueError thrown with incorrect parameters mock_params.return_value = None # No parameters saved From 036d6ea9f89c3dfa4cec8d122b2f4d942dac5104 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 15 Nov 2023 13:27:14 +0200 Subject: [PATCH 44/68] Issue #666 --- ibllib/io/extractors/habituation_trials.py | 93 +++++++++++++++------- 1 file changed, 63 insertions(+), 30 deletions(-) diff --git a/ibllib/io/extractors/habituation_trials.py b/ibllib/io/extractors/habituation_trials.py index 9dedbd3d5..59a29a269 100644 --- a/ibllib/io/extractors/habituation_trials.py +++ b/ibllib/io/extractors/habituation_trials.py @@ -1,12 +1,11 @@ +"""Habituation ChoiceWorld Bpod trials extraction.""" import logging import numpy as np import ibllib.io.raw_data_loaders as raw from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes from ibllib.io.extractors.biased_trials import ContrastLR -from ibllib.io.extractors.training_trials import ( - FeedbackTimes, StimOnTriggerTimes, Intervals, GoCueTimes -) +from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes _logger = logging.getLogger(__name__) @@ -24,9 +23,24 @@ def __init__(self, *args, **kwargs): self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names) def _extract(self) -> dict: + """ + Extract the Bpod trial events. + + The Bpod state machine for this task has extremely misleading names! The 'iti' state is + actually the delay between valve open and trial end (the stimulus is still present during + this period), and the 'trial_start' state is actually the ITI during which there is a 1s + Bpod TTL and gray screen period. + + Returns + ------- + dict + A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute. + """ # Extract all trials... - # Get all stim_sync events detected + # Get all detected TTLs. These are stored for QC purposes + self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials) + # These are the frame2TTL pulses as a list of lists, one per trial ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials] # Report missing events @@ -38,10 +52,45 @@ def _extract(self) -> dict: _logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)') # Extract datasets common to trainingChoiceWorld - training = [ContrastLR, FeedbackTimes, Intervals, GoCueTimes, StimOnTriggerTimes] + training = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes] out, _ = run_extractor_classes(training, session_path=self.session_path, save=False, bpod_trials=self.bpod_trials, settings=self.settings, task_collection=self.task_collection) + """ + The 'trial_start' state is in fact the 1s grey screen period, therefore the first timestamp + is really the end of the previous trial and also the stimOff trigger time. The second + timestamp is the true trial start time. + """ + (_, *ends), starts = zip(*[ + t['behavior_data']['States timestamps']['trial_start'][-1] for t in self.bpod_trials] + ) + + # StimOffTrigger times + out['stimOffTrigger_times'] = np.array(ends) + + # StimOff times + """ + There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse. + If 1 or more pulses are missing, we can not be confident of assigning the correct one. + """ + out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan + for sync, off in zip(ttls[1:], ends)]) + + # Trial intervals + """ + In terms of TTLs, the intervals are defined by the 'trial_start' state, however the stim + off time often happens after the trial end TTL front, i.e. after the 'trial_start' start + begins. For these trials, we set the trial end time as the stim off time. + """ + # NB: We lose the last trial because the stim off event occurs at trial_num + 1 + n_trials = out['stimOff_times'].size + out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :] + to_update = out['intervals'][:, 1] < out['stimOff_times'] + out['intervals'][to_update, 1] = out['stimOff_times'][to_update] + + # itiIn times + out['itiIn_times'] = np.r_[ends, np.nan] + # GoCueTriggerTimes is the same event as StimOnTriggerTimes out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy() @@ -75,38 +124,22 @@ def _extract(self) -> dict: trial_volume = [x['reward_amount'] for x in self.bpod_trials] out['rewardVolume'] = np.array(trial_volume).astype(np.float64) - # StimOffTrigger times - # StimOff occurs at trial start (ignore the first trial's state update) - out['stimOffTrigger_times'] = np.array( - [tr["behavior_data"]["States timestamps"] - ["trial_start"][0][0] for tr in self.bpod_trials[1:]] - ) - - # StimOff times - """ - There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse. - If 1 or more pulses are missing, we can not be confident of assigning the correct one. - """ - trigg = out['stimOffTrigger_times'] - out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan - for sync, off in zip(ttls[1:], trigg)]) - # FeedbackType is always positive out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8) - # ItiIn times - out['itiIn_times'] = np.array( - [tr["behavior_data"]["States timestamps"] - ["iti"][0][0] for tr in self.bpod_trials] - ) - # Phase and position out['position'] = np.array([t['position'] for t in self.bpod_trials]) out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials]) - # NB: We lose the last trial because the stim off event occurs at trial_num + 1 - n_trials = out['stimOff_times'].size - # return [out[k][:n_trials] for k in self.var_names] + # Double-check that the early and late trial events occur within the trial intervals + idx = ~np.isnan(out['stimOn_times'][:n_trials]) + if np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]): + _logger.warning('Stim on events occurring outside trial intervals') + idx = ~np.isnan(out['stimOff_times']) + if np.any(out['stimOff_times'][idx] > out['intervals'][idx, 1]): + _logger.warning('Stim off events occurring outside trial intervals') + + # Truncate arrays and return in correct order return {k: out[k][:n_trials] for k in self.var_names} From 0159729e7788bec355e63fa19d5ee8a84ff520bf Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 17 Nov 2023 15:19:16 +0200 Subject: [PATCH 45/68] FpgaTrials refactor and FpgaHabituationTrials subclass --- ibllib/io/extractors/base.py | 9 + ibllib/io/extractors/biased_trials.py | 2 + ibllib/io/extractors/ephys_fpga.py | 866 +++++++++++++++++-- ibllib/io/extractors/habituation_trials.py | 19 +- ibllib/io/session_params.py | 1 - ibllib/pipes/behavior_tasks.py | 64 +- ibllib/pipes/dynamic_pipeline.py | 28 +- ibllib/qc/task_metrics.py | 29 +- ibllib/tests/extractors/test_ephys_fpga.py | 187 +--- ibllib/tests/extractors/test_ephys_trials.py | 189 ++++ 10 files changed, 1092 insertions(+), 302 deletions(-) diff --git a/ibllib/io/extractors/base.py b/ibllib/io/extractors/base.py index c1b46b22e..cfc9557f4 100644 --- a/ibllib/io/extractors/base.py +++ b/ibllib/io/extractors/base.py @@ -29,9 +29,16 @@ class BaseExtractor(abc.ABC): """ session_path = None + """pathlib.Path: Absolute path of session folder.""" + save_names = None + """tuple of str: The filenames of each extracted dataset, or None if array should not be saved.""" + var_names = None + """tuple of str: A list of names for the extracted variables. These become the returned output keys.""" + default_path = Path('alf') # relative to session + """pathlib.Path: The default output folder relative to `session_path`.""" def __init__(self, session_path=None): # If session_path is None Path(session_path) will fail @@ -127,6 +134,8 @@ class BaseBpodTrialsExtractor(BaseExtractor): bpod_trials = None settings = None task_collection = None + frame2ttl = None + audio = None def extract(self, bpod_trials=None, settings=None, **kwargs): """ diff --git a/ibllib/io/extractors/biased_trials.py b/ibllib/io/extractors/biased_trials.py index 16d8f8111..e2912d11e 100644 --- a/ibllib/io/extractors/biased_trials.py +++ b/ibllib/io/extractors/biased_trials.py @@ -183,6 +183,8 @@ class EphysTrials(BaseBpodTrialsExtractor): def _extract(self, extractor_classes=None, **kwargs) -> dict: base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence] + # Get all detected TTLs. These are stored for QC purposes + self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials) # Exclude from trials table out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 74ac1e551..ad2cb0ab5 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -1,14 +1,46 @@ -"""Data extraction from raw FPGA output -Complete FPGA data extraction depends on Bpod extraction +"""Data extraction from raw FPGA output. + +The behaviour extraction happens in the following stages: + + 1. The NI DAQ events are extracted into a map of event times and TTL polarities. + 2. The Bpod trial events are extracted from the raw Bpod data, depending on the task protocol. + 3. As protocols may be chained together within a given recording, the period of a given task + protocol is determined using the 'spacer' DAQ signal (see `get_protocol_period`). + 4. Physical behaviour events such as stim on and reward time are separated out by TTL length or + sequence within the trial. + 5. The Bpod clock is sync'd with the FPGA using one of the extracted trial events. + 6. The Bpod software events are then converted to FPGA time. + +Examples +-------- +For simple extraction, use the FPGATrials class: + +>>> extractor = FpgaTrials(session_path) +>>> trials, _ = extractor.extract(update=False, save=False) + +Notes +----- +Sync extraction in this module only supports FPGA data acquired with an NI DAQ as part of a +Neuropixels recording system, however a sync and channel map extracted from a different DAQ format +can be passed to the FpgaTrials class. + +See Also +-------- +For dynamic pipeline sessions it is best to call the extractor via the BehaviorTask class. + +TODO notes on subclassing various methods of FpgaTrials for custom hardware. """ -from collections import OrderedDict import logging +from itertools import cycle from pathlib import Path import uuid import re +import warnings import matplotlib.pyplot as plt +from matplotlib.colors import TABLEAU_COLORS import numpy as np +from packaging import version import spikeglx import neurodsp.utils @@ -21,17 +53,22 @@ from ibllib.io.extractors.bpod_trials import extract_all as bpod_extract_all import ibllib.io.extractors.base as extractors_base from ibllib.io.extractors.training_wheel import extract_wheel_moves -import ibllib.plots as plots +from ibllib import plots from ibllib.io.extractors.default_channel_maps import DEFAULT_MAPS _logger = logging.getLogger(__name__) -SYNC_BATCH_SIZE_SECS = 100 # number of samples to read at once in bin file for sync +SYNC_BATCH_SIZE_SECS = 100 +"""int: Number of samples to read at once in bin file for sync.""" + WHEEL_RADIUS_CM = 1 # stay in radians +"""float: The radius of the wheel used in the task. A value of 1 ensures units remain in radians.""" + WHEEL_TICKS = 1024 +"""int: The number of encoder pulses per channel for one complete rotation.""" -BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150 # throws an error if bpod to fpga clock drift is higher -F2TTL_THRESH = 0.01 # consecutive pulses with less than this threshold ignored +BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150 +"""int: Throws an error if Bpod to FPGA clock drift is higher than this value.""" CHMAPS = {'3A': {'ap': @@ -62,10 +99,11 @@ {'imec_sync': 6} }, } +"""dict: The default channel indices corresponding to various devices for different recording systems.""" def data_for_keys(keys, data): - """Check keys exist in 'data' dict and contain values other than None""" + """Check keys exist in 'data' dict and contain values other than None.""" return data is not None and all(k in data and data.get(k, None) is not None for k in keys) @@ -157,6 +195,8 @@ def _assign_events_bpod(bpod_t, bpod_polarities, ignore_first_valve=True): :param bpod_fronts: numpy vector containing polarity of fronts (1 rise, -1 fall) :param ignore_first_valve (True): removes detected valve events at indices le 2 :return: numpy arrays of times t_trial_start, t_valve_open and t_iti_in + + TODO Remove function (now using FpgaTrials._assign_events) """ TRIAL_START_TTL_LEN = 2.33e-4 # the TTL length is 0.1ms but this has proven to drift on # some bpods and this is the highest possible value that discriminates trial start from valve @@ -258,6 +298,8 @@ def _assign_events_audio(audio_t, audio_polarities, return_indices=False, displa :param display (False): for debug mode, displays the raw fronts overlaid with detections :return: numpy arrays t_ready_tone_in, t_error_tone_in :return: numpy arrays ind_ready_tone_in, ind_error_tone_in if return_indices=True + + TODO Remove function (now using FpgaTrials._assign_events) """ # make sure that there are no 2 consecutive fall or consecutive rise events assert np.all(np.abs(np.diff(audio_polarities)) == 2) @@ -285,13 +327,29 @@ def _assign_events_to_trial(t_trial_start, t_event, take='last'): """ Assign events to a trial given trial start times and event times. - Trials without an event - result in nan value in output time vector. + Trials without an event result in nan value in output time vector. The output has a consistent size with t_trial_start and ready to output to alf. - :param t_trial_start: numpy vector of trial start times - :param t_event: numpy vector of event times to assign to trials - :param take: 'last' or 'first' (optional, default 'last'): index to take in case of duplicates - :return: numpy array of event times with the same shape of trial start. + + Parameters + ---------- + t_trial_start : numpy.array + An array of start times, used to bin edges for assigning values from `t_event`. + t_event : numpy.array + An array of event times to assign to trials. + take : str {'first', 'last'}, int + 'first' takes first event > t_trial_start; 'last' takes last event < the next + t_trial_start; an int defines the index to take for events within trial bounds. The index + may be negative. + + Returns + ------- + numpy.array + An array the length of `t_trial_start` containing values from `t_event`. Unassigned values + are replaced with np.nan. + + See Also + -------- + FpgaTrials._assign_events - Assign trial events based on TTL length. """ # make sure the events are sorted try: @@ -316,7 +374,7 @@ def _assign_events_to_trial(t_trial_start, t_event, take='last'): else: # if the index is arbitrary, needs to be numeric (could be negative if from the end) iall = np.unique(ind) minsize = take + 1 if take >= 0 else - take - # for each trial, take the takenth element if there are enough values in trial + # for each trial, take the take nth element if there are enough values in trial for iu in iall: match = t_event[iu == ind] if len(match) >= minsize: @@ -382,25 +440,39 @@ def _clean_audio(audio, display=False): return audio -def _clean_frame2ttl(frame2ttl, display=False): +def _clean_frame2ttl(frame2ttl, threshold=0.01, display=False): """ + Clean the frame2ttl events. + Frame 2ttl calibration can be unstable and the fronts may be flickering at an unrealistic pace. This removes the consecutive frame2ttl pulses happening too fast, below a threshold - of F2TTL_THRESH + of F2TTL_THRESH. + + Parameters + ---------- + frame2ttl : dict + A dictionary of frame2TTL events, with keys {'times', 'polarities'}. + threshold : float + Consecutive pulses occurring with this many seconds ignored. + display : bool + If true, plots the input TTLs and the cleaned output. + + Returns + ------- + """ dt = np.diff(frame2ttl['times']) - iko = np.where(np.logical_and(dt < F2TTL_THRESH, frame2ttl['polarities'][:-1] == -1))[0] + iko = np.where(np.logical_and(dt < threshold, frame2ttl['polarities'][:-1] == -1))[0] iko = np.unique(np.r_[iko, iko + 1]) frame2ttl_ = {'times': np.delete(frame2ttl['times'], iko), 'polarities': np.delete(frame2ttl['polarities'], iko)} if iko.size > (0.1 * frame2ttl['times'].size): _logger.warning(f'{iko.size} ({iko.size / frame2ttl["times"].size:.2%}) ' - f'frame to TTL polarity switches below {F2TTL_THRESH} secs') + f'frame to TTL polarity switches below {threshold} secs') if display: # pragma: no cover - from ibllib.plots import squares - plt.figure() - squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9]) - squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9]) + fig, (ax0, ax1) = plt.subplots(2, sharex=True) + plots.squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9], ax=ax0) + plots.squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9], ax=ax1) import seaborn as sns sns.displot(dt[dt < 0.05], binwidth=0.0005) @@ -425,9 +497,9 @@ def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None): Returns ------- - np.array + numpy.array Wheel timestamps in seconds. - np.array + numpy.array Wheel positions in radians. """ # Assume two separate edge count channels @@ -440,7 +512,7 @@ def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None): return re_ts, re_pos -def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tmin=None, tmax=None): +def extract_behaviour_sync(sync, chmap, display=False, bpod_trials=None, tmin=None, tmax=None): """ Extract task related event times from the sync. @@ -463,6 +535,8 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm ------- dict A map of trial event timestamps. + + TODO Remove this function (now using FpgaTrials.extract_behaviour_sync) """ bpod = get_sync_fronts(sync, chmap['bpod'], tmin=tmin, tmax=tmax) if bpod.times.size == 0: @@ -476,6 +550,7 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm t_trial_start, t_valve_open, t_iti_in = _assign_events_bpod(bpod['times'], bpod['polarities']) if not bpod_trials: raise ValueError('No Bpod trials to align') + intervals_bpod = bpod_trials['intervals'] # If there are no detected trial start times or more than double the trial end pulses, # the trial start pulses may be too small to be detected, in which case, sync using the ini_in if t_trial_start.size == 0 or (t_trial_start.size / t_iti_in.size) < .5: @@ -486,12 +561,12 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm # if it's drifting too much if drift > 200 and bpod_end.size != t_iti_in.size: raise err.SyncBpodFpgaException('sync cluster f*ck') - t_trial_start = fcn(bpod_trials['intervals_bpod'][:, 0]) + t_trial_start = fcn(intervals_bpod[:, 0]) else: # one issue is that sometimes bpod pulses may not have been detected, in this case # perform the sync bpod/FPGA, and add the start that have not been detected _logger.info('Attempting to align on trial start') - bpod_start = bpod_trials['intervals_bpod'][:, 0] + bpod_start = intervals_bpod[:, 0] fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps( bpod_start, t_trial_start, return_indices=True) # if it's drifting too much @@ -703,34 +778,39 @@ def get_protocol_period(session_path, protocol_number, bpod_sync): class FpgaTrials(extractors_base.BaseExtractor): - save_names = ('_ibl_trials.intervals_bpod.npy', - '_ibl_trials.goCueTrigger_times.npy', None, None, None, None, None, None, None, + save_names = ('_ibl_trials.goCueTrigger_times.npy', None, None, None, None, None, None, None, '_ibl_trials.stimOff_times.npy', None, None, None, '_ibl_trials.quiescencePeriod.npy', '_ibl_trials.table.pqt', '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy') - var_names = ('intervals_bpod', - 'goCueTrigger_times', 'stimOnTrigger_times', + var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'errorCue_times', 'itiIn_times', 'stimFreeze_times', 'stimOff_times', 'valveOpen_times', 'phase', 'position', 'quiescence', 'table', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude') - # Fields from bpod extractor that we want to re-sync to FPGA bpod_rsync_fields = ('intervals', 'response_times', 'goCueTrigger_times', 'stimOnTrigger_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times') + """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA.""" - # Fields from bpod extractor that we want to save bpod_fields = ('feedbackType', 'choice', 'rewardVolume', 'contrastLeft', 'contrastRight', - 'probabilityLeft', 'intervals_bpod', 'phase', 'position', 'quiescence') + 'probabilityLeft', 'phase', 'position', 'quiescence') + """tuple of str: Fields from bpod extractor that we want to save.""" + + sync_field = 'intervals_0' # trial start events + """str: The trial event to synchronize (must be present in extracted trials).""" - """str: The Bpod events to synchronize (must be present in sync channel map).""" - sync_field = 'intervals' + bpod = None + """dict of numpy.array: The Bpod out TTLs recorded on the DAQ. Used in the QC viewer plot.""" def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs): - """An extractor for all ephys trial data, in FPGA time""" + """An extractor for ephysChoiceWorld trials data, in FPGA time. + + This class may be subclassed to handle moderate variations in hardware and task protocol, + however there is flexible + """ super().__init__(*args, **kwargs) self.bpod2fpga = None self.bpod_trials = bpod_trials @@ -781,7 +861,7 @@ def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None): if not self.bpod_trials: self.bpod_trials = self.bpod_extractor.extract(save=False) table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() - self.bpod_fields += (*[x for x in table_keys if x not in excluded], self.sync_field + '_bpod') + self.bpod_fields += tuple([x for x in table_keys if x not in excluded]) @staticmethod def _time_fields(trials_attr) -> set: @@ -802,72 +882,266 @@ def _time_fields(trials_attr) -> set: pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$') return set(filter(pattern.match, trials_attr)) + def load_sync(self, sync_collection='raw_ephys_data', **kwargs): + """Load the DAQ sync and channel map data. + + This method may be subclassed for novel DAQ systems. The sync must contain the following + keys: 'times' - an array timestamps in seconds; 'polarities' - an array of {-1, 1} + corresponding to TTL LOW and TTL HIGH, respectively; 'channels' - an array of ints + corresponding to channel number. + + Parameters + ---------- + sync_collection : str + The session subdirectory where the sync data are located. + kwargs + Optional arguments used by subclass methods. + + Returns + ------- + one.alf.io.AlfBunch + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and + the corresponding channel numbers. + dict + A map of channel names and their corresponding indices. + """ + return get_sync_and_chn_map(self.session_path, sync_collection) + def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs) -> dict: - """Extracts ephys trials by combining Bpod and FPGA sync pulses""" - # extract the behaviour data from bpod + """Extracts ephys trials by combining Bpod and FPGA sync pulses. + + It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field` + attributes are all correct for the bpod protocol used. + + Below are the steps involved: + 0. Load sync and bpod trials, if required. + 1. Determine protocol period and discard sync events outside the task. + 2. Classify and attribute DAQ TTLs to trial events (see :meth:`FpgaTrials.extract_behaviour_sync`). + 3. Sync the Bpod clock to the DAQ clock using one of the assigned trial events. + 4. Convert Bpod software event times to DAQ clock. + 5. Extract the wheel from the DAQ rotary encoder signal, if required. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. If None, the sync is loaded using the + `load_sync` method. + chmap : dict + A map of channel names and their corresponding indices. If None, the channel map is + loaded using the :meth:`FpgaTrials.load_sync` method. + sync_collection : str + The session subdirectory where the sync data are located. This is only used if the + sync or channel maps are not provided. + task_collection : str + The session subdirectory where the raw Bpod data are located. This is used for loading + the task settings and extracting the bpod trials, if not already done. + protocol_number : int + The protocol number if multiple protocols were run during the session. If provided, a + spacer signal must be present in order to determine the correct period. + kwargs + Optional arguments for subclass methods to use. + + Returns + ------- + dict + A dictionary of numpy arrays with `FpgaTrials.var_names` as keys. + """ if sync is None or chmap is None: - _sync, _chmap = get_sync_and_chn_map(self.session_path, sync_collection) + _sync, _chmap = self.load_sync(sync_collection) sync = sync or _sync chmap = chmap or _chmap - if not self.bpod_trials: + if not self.bpod_trials: # extract the behaviour data from bpod self.bpod_trials, *_ = bpod_extract_all( session_path=self.session_path, task_collection=task_collection, save=False, extractor_type=kwargs.get('extractor_type')) + # Explode trials table df - trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table')) - table_columns = trials_table.keys() - self.bpod_trials.update(trials_table) - self.bpod_trials['intervals_bpod'] = np.copy(self.bpod_trials['intervals']) + if 'table' in self.var_names: + trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table')) + table_columns = trials_table.keys() + self.bpod_trials.update(trials_table) + else: + if 'table' in self.bpod_trials: + _logger.error( + '"table" found in Bpod trials but missing from `var_names` attribute and will' + 'therefore not be extracted. This is likely in error.') + table_columns = None # Get the spacer times for this protocol - if (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer + if any(arg in kwargs for arg in ('tmin', 'tmax')): + tmin, tmax = kwargs.get('tmin'), kwargs.get('tmax') + elif (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer # The spacers are TTLs generated by Bpod at the start of each protocol bpod = get_sync_fronts(sync, chmap['bpod']) tmin, tmax = get_protocol_period(self.session_path, protocol_number, bpod) else: tmin = tmax = None - # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC - fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync( - sync=sync, chmap=chmap, bpod_trials=self.bpod_trials, tmin=tmin, tmax=tmax) - assert self.sync_field in self.bpod_trials and self.sync_field in fpga_trials - self.bpod_trials[f'{self.sync_field}_bpod'] = np.copy(self.bpod_trials[self.sync_field]) - - # checks consistency and compute dt with bpod - self.bpod2fpga, drift_ppm, ibpod, ifpga = neurodsp.utils.sync_timestamps( - self.bpod_trials[f'{self.sync_field}_bpod'][:, 0], fpga_trials.pop(self.sync_field)[:, 0], - return_indices=True) - nbpod = self.bpod_trials[f'{self.sync_field}_bpod'].shape[0] - npfga = fpga_trials['feedback_times'].shape[0] - nsync = len(ibpod) - _logger.info(f'N trials: {nbpod} bpod, {npfga} FPGA, {nsync} merged, sync {drift_ppm} ppm') - if drift_ppm > BPOD_FPGA_DRIFT_THRESHOLD_PPM: - _logger.warning('BPOD/FPGA synchronization shows values greater than %i ppm', - BPOD_FPGA_DRIFT_THRESHOLD_PPM) - out = OrderedDict() + # Remove unnecessary data from sync + selection = np.logical_and( + sync['times'] <= (tmax if tmax is not None else sync['times'][-1]), + sync['times'] >= (tmin if tmin is not None else sync['times'][0]), + ) + sync = alfio.AlfBunch({k: v[selection] for k, v in sync.items()}) + _logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)', + *sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60) + + # Get the trial events from the DAQ sync TTLs + fpga_trials = self.extract_behaviour_sync(sync, chmap, **kwargs) + + # Sync the Bpod clock to the DAQ + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + + if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': + # One issue is that sometimes pulses may not have been detected, in this case + # add the events that have not been detected and re-extract the behaviour sync. + # This is only really relevant for the Bpod interval events as the other TTLs are + # from devices where a missing TTL likely means the Bpod event was truly absent. + _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') + bpod_start = self.bpod_trials['intervals'][:, 0] + missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) + t_trial_start = np.sort(np.r_[fpga_trials['intervals'][:, 0], missing_bpod]) + fpga_trials = self.extract_behaviour_sync(sync, chmap, start_times=t_trial_start, **kwargs) + + out = dict() out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) out.update({k: fpga_trials[k][ifpga] for k in sorted(fpga_trials.keys())}) # extract the wheel data - wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) - from ibllib.io.extractors.training_wheel import extract_first_movement_times - if not self.settings: - self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) - min_qt = self.settings.get('QUIESCENT_PERIOD', None) - first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) - out.update({'firstMovement_times': first_move_onsets}) + if any(x.startswith('wheel') for x in self.var_names): + wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) + from ibllib.io.extractors.training_wheel import extract_first_movement_times + if not self.settings: + self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) + min_qt = self.settings.get('QUIESCENT_PERIOD', None) + first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) + out.update({'firstMovement_times': first_move_onsets}) + out.update({f'wheel_{k}': v for k, v in wheel.items()}) + out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) + # Re-create trials table - trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) - out['table'] = trials_table.to_df() + if table_columns: + trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) + out['table'] = trials_table.to_df() - out.update({f'wheel_{k}': v for k, v in wheel.items()}) - out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) - out = {k: out[k] for k in self.var_names if k in out} # Reorder output + out = alfio.AlfBunch({k: out[k] for k in self.var_names if k in out}) # Reorder output assert self.var_names == tuple(out.keys()) return out + def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + """ + Extract task related event times from the sync. + + The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The + first trial start TTL of the session is longer and must be handled differently. The trial + start TTL is used to assign the other trial events to each trial. + + The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest + of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio + tones. The first of these after each trial start is taken to be the go cue time. Error + tones are longer audio TTLs and assigned as the last of such occurrence after each trial + start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial. + The feedback times are times of either valve open or error tone as there should be only one + such event per trial. + + The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs + removed): the first TTL after each trial start is assumed to be the stim onset time; the + second to last and last are taken as the stimulus freeze and offset times, respectively. + + Parameters + ---------- + sync : dict + 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' + chmap : dict + Map of channel names and their corresponding index. Default to constant. + start_times : numpy.array + An optional array of timestamps to separate trial events by. This is useful if after + syncing the clocks, some trial start TTLs are found to be missed. If None, uses + 'trial_start' Bpod event. + display : bool, matplotlib.pyplot.Axes + Show the full session sync pulses display. + + Returns + ------- + dict + A map of trial event timestamps. + """ + # Get the events from the sync. + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) + self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) + if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}: + raise ValueError( + 'Expected at least "ready_tone" and "error_tone" audio events.' + '`audio_event_ttls` kwarg may be incorrect.') + self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) + if not set(bpod_event_intervals.keys()) >= {'trial_start', 'valve_open', 'trial_end'}: + raise ValueError( + 'Expected at least "trial_start", "trial_end", and "valve_open" audio events. ' + '`bpod_event_ttls` kwarg may be incorrect.') + + # The first trial pulse is longer and often assigned to another event. + # Here we move the earliest non-trial_start event to the trial_start array. + t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start + pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] + if pretrial: + (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event + dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log + _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) + bpod_event_intervals['trial_start'] = np.r_[ + bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start'] + ] + bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] + + t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T + # Drop last trial start if incomplete + t_trial_start = bpod_event_intervals['trial_start'][:len(t_trial_end), 0] + t_valve_open = bpod_event_intervals['valve_open'][:, 0] + t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] + t_error_tone_in = audio_event_intervals['error_tone'][:, 0] + + start_times = start_times or t_trial_start + + trials = alfio.AlfBunch({ + 'goCue_times': _assign_events_to_trial(start_times, t_ready_tone_in, take='first'), + 'errorCue_times': _assign_events_to_trial(start_times, t_error_tone_in), + 'valveOpen_times': _assign_events_to_trial(start_times, t_valve_open), + 'stimFreeze_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2), + 'stimOn_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first'), + 'stimOff_times': _assign_events_to_trial(start_times, self.frame2ttl['times']), + 'itiIn_times': _assign_events_to_trial(start_times, t_iti_in) + }) + + # feedback times are valve open on correct trials and error tone in on incorrect trials + trials['feedback_times'] = np.copy(trials['valveOpen_times']) + ind_err = np.isnan(trials['valveOpen_times']) + trials['feedback_times'][ind_err] = trials['errorCue_times'][ind_err] + trials['intervals'] = np.c_[start_times, t_trial_end] + + if display: # pragma: no cover + width = 0.5 + ymax = 5 + if isinstance(display, bool): + plt.figure('Bpod FPGA Sync') + ax = plt.gca() + else: + ax = display + plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') + plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') + plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') + color_map = TABLEAU_COLORS.keys() + for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): + plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) + ax.legend() + ax.set_yticks([0, 1, 2]) + ax.set_yticklabels(['bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 5]) + + return trials + def get_wheel_positions(self, *args, **kwargs): """Extract wheel and wheelMoves objects. @@ -875,6 +1149,432 @@ def get_wheel_positions(self, *args, **kwargs): """ return get_wheel_positions(*args, **kwargs) + def get_stimulus_update_times(self, sync, chmap, display=False, **_): + """ + Extract stimulus update times from sync. + + Gets the stimulus times from the frame2ttl channel and cleans the signal. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + chmap : dict + A map of channel names and their corresponding indices. Must contain a 'frame2ttl' key. + display : bool + If true, plots the input TTLs and the cleaned output. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing stimulus TTL fronts. + """ + frame2ttl = get_sync_fronts(sync, chmap['frame2ttl']) + frame2ttl = _clean_frame2ttl(frame2ttl, display=display) + return frame2ttl + + def get_audio_event_times(self, sync, chmap, audio_event_ttls=None, display=False, **_): + """ + Extract audio times from sync. + + Gets the TTL times from the 'audio' channel, cleans the signal, and classifies each TTL + event by length. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + chmap : dict + A map of channel names and their corresponding indices. Must contain an 'audio' key. + audio_event_ttls : dict + A map of event names to (min, max) TTL length. + display : bool + If true, plots the input TTLs and the cleaned output. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing audio TTL fronts. + dict + A dictionary of events (from `audio_event_ttls`) and their intervals as an Nx2 array. + """ + audio = get_sync_fronts(sync, chmap['audio']) + audio = _clean_audio(audio) + + if audio['times'].size == 0: + _logger.error('No audio sync fronts found.') + + if audio_event_ttls is None: + # For training/biased/ephys protocols, the ready tone should be below 110 ms. The error + # tone should be between 400ms and 1200ms + audio_event_ttls = {'ready_tone': (0, 0.11), 'error_tone': (0.4, 1.2)} + audio_event_intervals = self._assign_events(audio['times'], audio['polarities'], audio_event_ttls, display=display) + + return audio, audio_event_intervals + + def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): + """ + Extract Bpod times from sync. + + Gets the Bpod TTL times from the sync 'bpod' channel and classifies each TTL event by + length. NB: The first trial has an abnormal trial_start TTL that is usually mis-assigned. + This is handled in the :meth:`FpgaTrials.extract_behaviour_sync` method. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. Must contain a 'bpod' key. + chmap : dict + A map of channel names and their corresponding indices. + bpod_event_ttls : dict of tuple + A map of event names to (min, max) TTL length. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. + dict + A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. + """ + bpod = get_sync_fronts(sync, chmap['bpod']) + if bpod.times.size == 0: + raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. ' + 'Check channel maps.') + # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these + # lengths are defined by the state machine of the task protocol and therefore vary. + if bpod_event_ttls is None: + # For training/biased/ephys protocols, the trial start TTL length is 0.1ms but this has + # proven to drift on some Bpods and this is the highest possible value that + # discriminates trial start from valve. Valve open events are between 50ms to 300 ms. + # ITI events are above 400 ms. + bpod_event_ttls = { + 'trial_start': (0, 2.33e-4), 'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)} + bpod_event_intervals = self._assign_events( + bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) + + return bpod, bpod_event_intervals + + @staticmethod + def _assign_events(ts, polarities, event_lengths, precedence='shortest', display=False): + """ + Classify TTL events by length. + + Outputs the synchronisation events such as trial intervals, valve opening, and audio. + + Parameters + ---------- + ts : numpy.array + Numpy vector containing times of TTL fronts. + polarities : numpy.array + Numpy vector containing polarity of TTL fronts (1 rise, -1 fall). + event_lengths : dict of tuple + A map of TTL events and the range of permissible lengths, where l0 < ttl <= l1. + precedence : str {'shortest', 'longest', 'dict order'} + In the case of overlapping event TTL lengths, assign shortest/longest first or go by + the `event_lengths` dict order. + display : bool + If true, plots the TTLs with coloured lines delineating the assigned events. + + Returns + ------- + Dict[str, numpy.array] + A dictionary of events and their intervals as an Nx2 array. + + See Also + -------- + _assign_events_to_trial - classify TTLs by event order within a given trial period. + """ + event_intervals = dict.fromkeys(event_lengths) + assert 'unassigned' not in event_lengths.keys() + + if len(ts) == 0: + return {k: np.array([[], []]).T for k in (*event_lengths.keys(), 'unassigned')} + + # make sure that there are no 2 consecutive fall or consecutive rise events + assert np.all(np.abs(np.diff(polarities)) == 2) + if polarities[0] == -1: + ts = np.delete(ts, 0) + if polarities[-1] == 1: # if the final TTL is left HIGH, insert a NaN + ts = np.r_[ts, np.nan] + # take only even time differences: i.e. from rising to falling fronts + dt = np.diff(ts)[::2] + + # Assign events from shortest TTL to largest + assigned = np.zeros(ts.shape, dtype=bool) + if precedence.lower() == 'shortest': + event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1])) + elif precedence.lower() == 'longest': + event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1]), reverse=True) + elif precedence.lower() == 'dict order': + event_items = event_lengths.items() + else: + raise ValueError(f'Precedence must be one of "shortest", "longest", "dict order", got "{precedence}".') + for event, (min_len, max_len) in event_items: + _logger.debug('%s: %.4G < ttl <= %.4G', event, min_len, max_len) + i_event = np.where(np.logical_and(dt > min_len, dt <= max_len))[0] * 2 + i_event = i_event[np.where(~assigned[i_event])[0]] # remove those already assigned + event_intervals[event] = np.c_[ts[i_event], ts[i_event + 1]] + assigned[np.r_[i_event, i_event + 1]] = True + + # Include the unassigned events for convenience and debugging + event_intervals['unassigned'] = ts[~assigned].reshape(-1, 2) + + # Assert that event TTLs mutually exclusive + all_assigned = np.concatenate(list(event_intervals.values())).flatten() + assert all_assigned.size == np.unique(all_assigned).size, 'TTLs assigned to multiple events' + + # some debug plots when needed + if display: # pragma: no cover + plt.figure() + plots.squares(ts, polarities, label='raw fronts') + for event, intervals in event_intervals.items(): + plots.vertical_lines(intervals[:, 0], ymin=-0.2, ymax=1.1, linewidth=0.5, label=event) + plt.legend() + + # Return map of event intervals in the same order as `event_lengths` dict + return {k: event_intervals[k] for k in (*event_lengths, 'unassigned')} + + @staticmethod + def sync_bpod_clock(bpod_trials, fpga_trials, sync_field): + """ + Sync the Bpod clock to FPGA one using the provided trial event. + + It assumes that `sync_field` is in both `fpga_trials` and `bpod_trials`. Syncing on both + intervals is not supported so to sync on trial start times, `sync_field` should be + 'intervals_0'. + + Parameters + ---------- + bpod_trials : dict + A dictionary of extracted Bpod trial events. + fpga_trials : dict + A dictionary of trial events extracted from FPGA sync events (see + `extract_behaviour_sync` method). + sync_field : str + The trials key to use for syncing clocks. For intervals (i.e. Nx2 arrays) append the + column index, e.g. 'intervals_0'. + + Returns + ------- + function + Interpolation function such that f(timestamps_bpod) = timestamps_fpga. + float + The clock drift in parts per million. + numpy.array of int + The indices of the Bpod trial events in the FPGA trial events array. + numpy.array of int + The indices of the FPGA trial events in the Bpod trial events array. + + Raises + ------ + ValueError + The key `sync_field` was not found in either the `bpod_trials` or `fpga_trials` dicts. + """ + _logger.info(f'Attempting to align Bpod clock to DAQ using trial event "{sync_field}"') + if sync_field not in bpod_trials: + # handle syncing on intervals + if not (m := re.match(r'(.*)_(\d)', sync_field)): + raise ValueError(f'Sync field "{sync_field}" not in extracted bpod trials') + sync_field, i = m.groups() + timestamps_bpod = bpod_trials[sync_field][:, int(i)] + timestamps_fpga = fpga_trials[sync_field][:, int(i)] + elif sync_field not in fpga_trials: + raise ValueError(f'Sync field "{sync_field}" not in extracted fpga trials') + else: + timestamps_bpod = bpod_trials[sync_field] + timestamps_fpga = fpga_trials[sync_field] + + # Sync the two timestamps + fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps( + timestamps_bpod, timestamps_fpga, return_indices=True) + + # If it's drifting too much throw warning or error + _logger.info('N trials: %i bpod, %i FPGA, %i merged, sync %.5f ppm', + len(timestamps_bpod), len(timestamps_fpga), len(ibpod), drift) + if drift > 200 and timestamps_bpod.size != timestamps_fpga.size: + raise err.SyncBpodFpgaException('sync cluster f*ck') + elif drift > BPOD_FPGA_DRIFT_THRESHOLD_PPM: + _logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm', + BPOD_FPGA_DRIFT_THRESHOLD_PPM) + + return fcn, drift, ibpod, ifpga + + +class FpgaTrialsHabituation(FpgaTrials): + """Extract habituationChoiceWorld trial events from an NI DAQ.""" + + save_names = ('_ibl_trials.stimCenter_times.npy', '_ibl_trials.feedbackType.npy', '_ibl_trials.rewardVolume.npy', + '_ibl_trials.stimOff_times.npy', '_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', + '_ibl_trials.feedback_times.npy', '_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOnTrigger_times.npy', + '_ibl_trials.intervals.npy', '_ibl_trials.goCue_times.npy', '_ibl_trials.goCueTrigger_times.npy', + None, None, None, None, None) + """tuple of str: The filenames of each extracted dataset, or None if array should not be saved.""" + + var_names = ('stimCenter_times', 'feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', + 'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals', + 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times', + 'stimCenterTrigger_times', 'position', 'phase') + """tuple of str: A list of names for the extracted variables. These become the returned output keys.""" + + bpod_rsync_fields = ('intervals', 'stimOn_times', 'feedback_times', 'stimCenterTrigger_times', + 'goCue_times', 'itiIn_times', 'stimOffTrigger_times', 'stimOff_times', + 'stimCenter_times', 'stimOnTrigger_times', 'goCueTrigger_times') + """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA.""" + + bpod_fields = ('feedbackType', 'rewardVolume', 'contrastLeft', 'contrastRight', 'position', 'phase') + """tuple of str: Fields from Bpod extractor that we want to save.""" + + sync_field = 'feedback_times' # valve open events + """str: The trial event to synchronize (must be present in extracted trials).""" + + def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', + task_collection='raw_behavior_data', **kwargs) -> dict: + """ + Extract habituationChoiceWorld trial events from an NI DAQ. + + It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field` + attributes are all correct for the bpod protocol used. + + Unlike FpgaTrials, this class assumes different Bpod TTL events and syncs the Bpod clock + using the valve open times, instead of the trial start times. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. If None, the sync is loaded using the + `load_sync` method. + dict + A map of channel names and their corresponding indices. If None, the channel map is + loaded using the `load_sync` method. + sync_collection : str + The session subdirectory where the sync data are located. This is only used if the + sync or channel maps are not provided. + task_collection : str + The session subdirectory where the raw Bpod data are located. This is used for loading + the task settings and extracting the bpod trials, if not already done. + protocol_number : int + The protocol number if multiple protocols were run during the session. If provided, a + spacer signal must be present in order to determine the correct period. + kwargs + Optional arguments for class methods, e.g. 'display', 'bpod_event_ttls'. + + Returns + ------- + dict + A dictionary of numpy arrays with `FpgaTrialsHabituation.var_names` as keys. + """ + # Version check: the ITI in TTL was added in a later version + iblrig_version = version.parse(self.settings.get('IBL_VERSION', '0.0.0')) + if version.parse('8.9.3') <= iblrig_version < version.parse('8.12.6'): + """A second 1s TTL was added in this version during the 'iti' state, however this is + unrelated to the trial ITI and is unfortunately the same length as the trial start TTL.""" + raise NotImplementedError('Ambiguous TTLs in 8.9.3 >= version < 8.12.6') + + # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse + if 'bpod_event_ttls' not in kwargs: + kwargs['bpod_event_ttls'] = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)} + trials = super()._extract(sync=sync, chmap=chmap, sync_collection=sync_collection, + task_collection=task_collection, **kwargs) + + n = trials['intervals'].shape[0] # number of trials + trials['intervals'][:, 1] = self.bpod2fpga(self.bpod_trials['intervals'][:n, 1]) + + return trials + + def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + """ + Extract task related event times from the sync. + + This is called by the superclass `_extract` method. The key difference here is that the + `trial_start` LOW->HIGH is the trial end, and HIGH->LOW is trial start. + + Parameters + ---------- + sync : dict + 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' + chmap : dict + Map of channel names and their corresponding index. Default to constant. + start_times : numpy.array + An optional array of timestamps to separate trial events by. This is useful if after + syncing the clocks, some trial start TTLs are found to be missed. If None, uses + 'trial_start' Bpod event. + display : bool, matplotlib.pyplot.Axes + Show the full session sync pulses display. + + Returns + ------- + dict + A map of trial event timestamps. + """ + # Get the events from the sync. + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) + self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) + self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) + if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_iti'}: + raise ValueError( + 'Expected at least "trial_iti" and "valve_open" Bpod events. `bpod_event_ttls` kwarg may be incorrect.') + + # The first trial pulse is shorter and assigned to valve_open. Here we remove the first + # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was + # incomplete in Bpod. + n_trials = self.bpod_trials['intervals'].shape[0] + t_valve_open = bpod_event_intervals['valve_open'][1:, 0] # drop first spurious valve event + t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] + t_trial_start = np.r_[0, bpod_event_intervals['trial_iti'][:, 1]] + t_trial_end = bpod_event_intervals['trial_iti'][:, 0] + + start_times = start_times or t_trial_start + + trials = alfio.AlfBunch({ + 'goCue_times': _assign_events_to_trial(start_times, t_ready_tone_in, take='first')[:n_trials], + 'feedback_times': _assign_events_to_trial(start_times, t_valve_open)[:n_trials], + 'stimFreeze_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2)[:n_trials], + 'stimOn_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first')[:n_trials], + 'stimOff_times': _assign_events_to_trial(start_times, self.frame2ttl['times'])[:n_trials], + # These 'raw' intervals will be used in the sync + 'intervals_1': _assign_events_to_trial(start_times, t_trial_end), + 'intervals_0': start_times + }) + + # If stim on occurs before trial end, use stim on time. Likewise for trial end and stim off + trials['intervals'] = np.c_[trials['intervals_0'], trials['intervals_1']][:n_trials, :] + to_correct = ~np.isnan(trials['stimOn_times']) & (trials['stimOn_times'] < trials['intervals'][:, 0]) + if np.any(to_correct): + _logger.warning('%i/%i stim on events occurring outside trial intervals', sum(to_correct), len(to_correct)) + trials['intervals'][to_correct, 0] = trials['stimOn_times'][to_correct] + to_correct = ~np.isnan(trials['stimOff_times']) & (trials['stimOff_times'] > trials['intervals'][:, 1]) + if np.any(to_correct): + _logger.debug( + '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end', + sum(to_correct), len(to_correct)) + trials['intervals'][to_correct, 1] = trials['stimOff_times'][to_correct] + + if display: # pragma: no cover + width = 0.5 + ymax = 5 + if isinstance(display, bool): + plt.figure('Bpod FPGA Sync') + ax = plt.gca() + else: + ax = display + plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') + plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') + plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') + color_map = TABLEAU_COLORS.keys() + for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): + plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) + ax.legend() + ax.set_yticks([0, 1, 2]) + ax.set_yticklabels(['bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 5]) + + return trials + def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_path=None, task_collection='raw_behavior_data', protocol_number=None, **kwargs): @@ -883,7 +1583,11 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ - sync - wheel - behaviour - - video time stamps + + These `extract_all` functions should be deprecated as they make assumptions about hardware + parameters. Additionally the FpgaTrials class now automatically loads DAQ sync files, extracts + the Bpod trials, and returns a dict instead of a tuple. Therefore this function is entirely + redundant. See the examples for the correct way to extract NI DAQ behaviour sessions. Parameters ---------- @@ -909,6 +1613,10 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ list of pathlib.Path, None If save is True, a list of file paths to the extracted data. """ + warnings.warn( + 'ibllib.io.extractors.ephys_fpga.extract_all will be removed in future versions; ' + 'use FpgaTrials instead. For reliable extraction, use the dynamic pipeline behaviour tasks.', + FutureWarning) # Extract Bpod trials bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' diff --git a/ibllib/io/extractors/habituation_trials.py b/ibllib/io/extractors/habituation_trials.py index 59a29a269..655ea2de1 100644 --- a/ibllib/io/extractors/habituation_trials.py +++ b/ibllib/io/extractors/habituation_trials.py @@ -73,8 +73,7 @@ def _extract(self) -> dict: There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse. If 1 or more pulses are missing, we can not be confident of assigning the correct one. """ - out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan - for sync, off in zip(ttls[1:], ends)]) + out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]]) # Trial intervals """ @@ -85,8 +84,13 @@ def _extract(self) -> dict: # NB: We lose the last trial because the stim off event occurs at trial_num + 1 n_trials = out['stimOff_times'].size out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :] - to_update = out['intervals'][:, 1] < out['stimOff_times'] - out['intervals'][to_update, 1] = out['stimOff_times'][to_update] + + to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1]) + if np.any(to_correct): + _logger.debug( + '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end', + sum(to_correct), len(to_correct)) + out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct] # itiIn times out['itiIn_times'] = np.r_[ends, np.nan] @@ -133,11 +137,8 @@ def _extract(self) -> dict: # Double-check that the early and late trial events occur within the trial intervals idx = ~np.isnan(out['stimOn_times'][:n_trials]) - if np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]): - _logger.warning('Stim on events occurring outside trial intervals') - idx = ~np.isnan(out['stimOff_times']) - if np.any(out['stimOff_times'][idx] > out['intervals'][idx, 1]): - _logger.warning('Stim off events occurring outside trial intervals') + assert not np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]), \ + 'Stim on events occurring outside trial intervals' # Truncate arrays and return in correct order return {k: out[k][:n_trials] for k in self.var_names} diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index 5bcaf2873..fd9854455 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -413,7 +413,6 @@ def iter_dict(d): for d in filter(lambda x: isinstance(x, dict), v): iter_dict(d) elif isinstance(v, dict) and 'collection' in v: - print(k) # if the key already exists, append the collection name to the list if k in collection_map: clist = collection_map[k] if isinstance(collection_map[k], list) else [collection_map[k]] diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py index 6f1c8d506..85e21c7ac 100644 --- a/ibllib/pipes/behavior_tasks.py +++ b/ibllib/pipes/behavior_tasks.py @@ -14,7 +14,7 @@ from ibllib.qc.task_metrics import HabituationQC, TaskQC from ibllib.io.extractors.ephys_passive import PassiveChoiceWorld from ibllib.io.extractors.bpod_trials import get_bpod_extractor -from ibllib.io.extractors.ephys_fpga import FpgaTrials, get_sync_and_chn_map +from ibllib.io.extractors.ephys_fpga import FpgaTrials, FpgaTrialsHabituation, get_sync_and_chn_map from ibllib.io.extractors.mesoscope import TimelineTrials from ibllib.pipes import training_status from ibllib.plots.figures import BehaviourPlots @@ -102,14 +102,61 @@ def _run_qc(self, trials_data=None, update=True): qc.extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, sync_type=self.sync, task_collection=self.collection) - # Currently only the data field is accessed + # Update extractor fields qc.extractor.data = qc.extractor.rename_data(trials_data.copy()) + qc.extractor.frame_ttls = self.extractor.frame2ttl # used in iblapps QC viewer + qc.extractor.audio_ttls = self.extractor.audio # used in iblapps QC viewer + qc.extractor.settings = self.extractor.settings namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' qc.run(update=update, namespace=namespace) return qc +class HabituationTrialsNidq(HabituationTrialsBpod): + priority = 90 + job_size = 'small' + + @property + def signature(self): + signature = super().signature + signature['input_files'] = [ + ('_iblrig_taskData.raw.*', self.collection, True), + ('_iblrig_taskSettings.raw.*', self.collection, True), + (f'_{self.sync_namespace}_sync.channels.npy', self.sync_collection, True), + (f'_{self.sync_namespace}_sync.polarities.npy', self.sync_collection, True), + (f'_{self.sync_namespace}_sync.times.npy', self.sync_collection, True), + ('*wiring.json', self.sync_collection, False), + ('*.meta', self.sync_collection, True)] + return signature + + def _extract_behaviour(self, save=True, **kwargs): + """Extract the habituationChoiceWorld trial data using NI DAQ clock.""" + # Extract Bpod trials + bpod_trials, _ = super()._extract_behaviour(save=False, **kwargs) + + # Sync Bpod trials to FPGA + sync, chmap = get_sync_and_chn_map(self.session_path, self.sync_collection) + self.extractor = FpgaTrialsHabituation( + self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor) + + # NB: The stimOff times are called stimCenter times for habituation choice world + outputs, files = self.extractor.extract( + save=save, sync=sync, chmap=chmap, path_out=self.session_path.joinpath(self.output_collection), + task_collection=self.collection, protocol_number=self.protocol_number, **kwargs) + return outputs, files + + def _run_qc(self, trials_data=None, update=True, **_): + """Run and update QC. + + This adds the bpod TTLs to the QC object *after* the QC is run in the super call method. + The raw Bpod TTLs are not used by the QC however they are used in the iblapps QC plot. + """ + qc = super()._run_qc(trials_data=trials_data, update=update) + qc.extractor.bpod_ttls = self.extractor.bpod + return qc + + class TrialRegisterRaw(base_tasks.RegisterRawDataTask, base_tasks.BehaviourTask): priority = 100 job_size = 'small' @@ -286,9 +333,9 @@ def _run_qc(self, trials_data=None, update=True): else: qc = TaskQC(self.session_path, one=self.one, log=_logger) qc_extractor.wheel_encoding = 'X1' - qc_extractor.settings = self.extractor.settings - qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts( - self.session_path, task_collection=self.collection) + qc_extractor.settings = self.extractor.settings + qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts( + self.session_path, task_collection=self.collection) qc.extractor = qc_extractor # Aggregate and update Alyx QC fields @@ -370,14 +417,15 @@ def _run_qc(self, trials_data=None, update=False, plot_qc=False): qc = HabituationQC(self.session_path, one=self.one, log=_logger) else: qc = TaskQC(self.session_path, one=self.one, log=_logger) - qc_extractor.settings = self.extractor.settings # Add Bpod wheel data wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps']) qc_extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod qc_extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] qc_extractor.wheel_encoding = 'X4' - qc_extractor.frame_ttls = self.extractor.frame2ttl - qc_extractor.audio_ttls = self.extractor.audio + qc_extractor.frame_ttls = self.extractor.frame2ttl + qc_extractor.audio_ttls = self.extractor.audio + qc_extractor.bpod_ttls = self.extractor.bpod + qc_extractor.settings = self.extractor.settings qc.extractor = qc_extractor # Aggregate and update Alyx QC fields diff --git a/ibllib/pipes/dynamic_pipeline.py b/ibllib/pipes/dynamic_pipeline.py index bc2caaf1b..3c72853fb 100644 --- a/ibllib/pipes/dynamic_pipeline.py +++ b/ibllib/pipes/dynamic_pipeline.py @@ -230,22 +230,28 @@ def make_pipeline(session_path, **pkwargs): # - choice_world_biased # - choice_world_training # - choice_world_habituation - if 'habituation' in protocol: - registration_class = btasks.HabituationRegisterRaw - behaviour_class = btasks.HabituationTrialsBpod - compute_status = False - elif 'passiveChoiceWorld' in protocol: + if 'passiveChoiceWorld' in protocol: registration_class = btasks.PassiveRegisterRaw behaviour_class = btasks.PassiveTask compute_status = False elif sync_kwargs['sync'] == 'bpod': - registration_class = btasks.TrialRegisterRaw - behaviour_class = btasks.ChoiceWorldTrialsBpod - compute_status = True + if 'habituation' in protocol: + registration_class = btasks.HabituationRegisterRaw + behaviour_class = btasks.HabituationTrialsBpod + compute_status = False + else: + registration_class = btasks.TrialRegisterRaw + behaviour_class = btasks.ChoiceWorldTrialsBpod + compute_status = True elif sync_kwargs['sync'] == 'nidq': - registration_class = btasks.TrialRegisterRaw - behaviour_class = btasks.ChoiceWorldTrialsNidq - compute_status = True + if 'habituation' in protocol: + registration_class = btasks.HabituationRegisterRaw + behaviour_class = btasks.HabituationTrialsNidq + compute_status = False + else: + registration_class = btasks.TrialRegisterRaw + behaviour_class = btasks.ChoiceWorldTrialsNidq + compute_status = True else: raise NotImplementedError tasks[f'RegisterRaw_{protocol}_{i:02}'] = type(f'RegisterRaw_{protocol}_{i:02}', (registration_class,), {})( diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 42361645d..d746626d5 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -1,4 +1,5 @@ -"""Behaviour QC +"""Behaviour QC. + This module runs a list of quality control metrics on the behaviour data. Examples @@ -179,20 +180,22 @@ def run(self, update=False, namespace='task', **kwargs): return outcome, results @staticmethod - def compute_session_status_from_dict(results): + def compute_session_status_from_dict(results, criteria=None): """ Given a dictionary of results, computes the overall session QC for each key and aggregates in a single value - :param results: a dictionary of qc keys containing (usually scalar) values + :param results: a dictionary of qc keys containing (usually scalar) values. + :param criteria: a dictionary of qc keys containing map of PASS, WARNING, FAIL thresholds. :return: Overall session QC outcome as a string :return: A dict of QC tests and their outcomes """ indices = np.zeros(len(results), dtype=int) + criteria = criteria or TaskQC.criteria for i, k in enumerate(results): - if k in TaskQC.criteria.keys(): - indices[i] = TaskQC._thresholding(results[k], thresholds=TaskQC.criteria[k]) + if k in criteria.keys(): + indices[i] = TaskQC._thresholding(results[k], thresholds=criteria[k]) else: - indices[i] = TaskQC._thresholding(results[k], thresholds=TaskQC.criteria['default']) + indices[i] = TaskQC._thresholding(results[k], thresholds=criteria['default']) def key_map(x): return 'NOT_SET' if x < 0 else list(TaskQC.criteria['default'].keys())[x] @@ -213,14 +216,19 @@ def compute_session_status(self): # Get mean passed of each check, or None if passed is None or all NaN results = {k: None if v is None or np.isnan(v).all() else np.nanmean(v) for k, v in self.passed.items()} - session_outcome, outcomes = self.compute_session_status_from_dict(results) + session_outcome, outcomes = self.compute_session_status_from_dict(results, self.criteria) return session_outcome, results, outcomes class HabituationQC(TaskQC): - def compute(self, download_data=None): - """Compute and store the QC metrics + criteria = dict() + criteria['default'] = {'PASS': 0.99, 'WARNING': 0.90, 'FAIL': 0} # Note: WARNING was 0.95 prior to Aug 2022 + criteria['_task_phase_distribution'] = {'PASS': 0.99, 'NOT_SET': 0} # This rarely passes due to low trial num + + def compute(self, download_data=None, **kwargs): + """Compute and store the QC metrics. + Runs the QC on the session and stores a map of the metrics for each datapoint for each test, and a map of which datapoints passed for each test :return: @@ -228,7 +236,7 @@ def compute(self, download_data=None): if self.extractor is None: # If download_data is None, decide based on whether eid or session path was provided ensure_data = self.download_data if download_data is None else download_data - self.load_data(download_data=ensure_data) + self.load_data(download_data=ensure_data, **kwargs) self.log.info(f'Session {self.session_path}: Running QC on habituation data...') # Initialize checks @@ -302,6 +310,7 @@ def compute(self, download_data=None): passed[check] = (metric <= 2 * np.pi) & (metric >= 0) metrics[check] = metric + # This is not very useful as a check because there are so few trials check = prefix + 'phase_distribution' metric, _ = np.histogram(data['phase']) _, p = chisquare(metric) diff --git a/ibllib/tests/extractors/test_ephys_fpga.py b/ibllib/tests/extractors/test_ephys_fpga.py index ca211e426..465322810 100644 --- a/ibllib/tests/extractors/test_ephys_fpga.py +++ b/ibllib/tests/extractors/test_ephys_fpga.py @@ -1,15 +1,11 @@ +"""Tests for ephys FPGA sync and FPGA wheel extraction.""" import unittest import tempfile from pathlib import Path -import pickle -import logging import numpy as np -from ibllib.io.extractors.training_wheel import extract_first_movement_times, infer_wheel_units from ibllib.io.extractors import ephys_fpga -from ibllib.io.extractors.training_wheel import extract_wheel_moves -import brainbox.behavior.wheel as wh import spikeglx @@ -88,189 +84,12 @@ def test_ibl_sync_maps(self): self.assertEqual(s, ephys_fpga.CHMAPS['3B']['ap']) -class TestWheelExtraction(unittest.TestCase): - - def setUp(self) -> None: - self.ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - self.pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - self.tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) - self.pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - - def test_x1_decoding(self): - p_ = np.array([1, 2, 1, 0]) - t_ = np.array([2, 6, 11, 15]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x1') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(p == p_)) - - def test_x4_decoding(self): - p_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 0]) / 4 - t_ = np.array([2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x4') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(np.isclose(p, p_))) - - def test_x2_decoding(self): - p_ = np.array([1, 2, 3, 4, 3, 2, 1, 0]) / 2 - t_ = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x2') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(p == p_)) - - -class TestExtractedWheelUnits(unittest.TestCase): - """Tests the infer_wheel_units function""" - - wheel_radius_cm = 3.1 - - def setUp(self) -> None: - """ - Create the wheel position data for testing: the positions attribute holds a dictionary of - units, each holding a dictionary of encoding types to test, e.g. - - positions = { - 'rad': { - 'X1': ..., - 'X2': ..., - 'X4': ... - }, - 'cm': { - 'X1': ..., - 'X2': ..., - 'X4': ... - } - } - :return: - """ - def x(unit, enc=int(1), wheel_radius=self.wheel_radius_cm): - radius = 1 if unit == 'rad' else wheel_radius - return 1 / ephys_fpga.WHEEL_TICKS * np.pi * 2 * radius / enc - - # A pseudo-random sequence of integrated fronts - seq = np.array([-1, 0, 1, 2, 1, 2, 3, 4, 3, 2, 1, 0, -1, -2, 1, -2]) - encs = (1, 2, 4) # Encoding types to test - units = ('rad', 'cm') # Units to test - self.positions = {unit: {f'X{e}': x(unit, e) * seq for e in encs} for unit in units} - - def test_extract_wheel_moves(self): - for unit in self.positions.keys(): - for encoding, pos in self.positions[unit].items(): - result = infer_wheel_units(pos) - self.assertEqual(unit, result[0], f'failed to determine units for {encoding}') - expected = int(ephys_fpga.WHEEL_TICKS * int(encoding[1])) - self.assertEqual(expected, result[1], - f'failed to determine number of ticks for {encoding} in {unit}') - self.assertEqual(encoding, result[2], f'failed to determine encoding in {unit}') - - -class TestWheelMovesExtraction(unittest.TestCase): - - def setUp(self) -> None: - """ - Test data is in the form ((inputs), (outputs)) where inputs is a tuple containing a - numpy array of timestamps and one of positions; outputs is a tuple of outputs from - the functions. For details, see help on TestWheel.setUp method in module - brainbox.tests.test_behavior - """ - pickle_file = Path(__file__).parents[3].joinpath( - 'brainbox', 'tests', 'fixtures', 'wheel_test.p') - if not pickle_file.exists(): - self.test_data = None - else: - with open(pickle_file, 'rb') as f: - self.test_data = pickle.load(f) - - # Some trial times for trial_data[1] - self.trials = { - 'goCue_times': np.array([162.5, 105.6, 55]), - 'feedback_times': np.array([164.3, 108.3, 56]) - } - - def test_extract_wheel_moves(self): - test_data = self.test_data[1] - # Wrangle data into expected form - re_ts = test_data[0][0] - re_pos = test_data[0][1] - - logger = logging.getLogger(logname := 'ibllib.io.extractors.training_wheel') - with self.assertLogs(logger, level='INFO') as cm: - wheel_moves = extract_wheel_moves(re_ts, re_pos) - self.assertEqual([f'INFO:{logname}:Wheel in cm units using X2 encoding'], cm.output) - - n = 56 # expected number of movements - self.assertTupleEqual(wheel_moves['intervals'].shape, (n, 2), - 'failed to return the correct number of intervals') - self.assertEqual(wheel_moves['peakAmplitude'].size, n) - self.assertEqual(wheel_moves['peakVelocity_times'].size, n) - - # Check the first 3 intervals - ints = np.array( - [[24.78462599, 25.22562599], - [29.58762599, 31.15062599], - [31.64262599, 31.81662599]]) - actual = wheel_moves['intervals'][:3, ] - self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') - - # Check amplitudes - actual = wheel_moves['peakAmplitude'][-3:] - expected = [0.50255486, -1.70103154, 1.00740789] - self.assertIsNone(np.testing.assert_allclose(actual, expected), 'unexpected amplitudes') - - # Check peak velocities - actual = wheel_moves['peakVelocity_times'][-3:] - expected = [175.13662599, 176.65762599, 178.57262599] - self.assertIsNone(np.testing.assert_allclose(actual, expected), 'peak times') - - # Test extraction in rad - re_pos = wh.cm_to_rad(re_pos) - with self.assertLogs(logger, level='INFO') as cm: - wheel_moves = ephys_fpga.extract_wheel_moves(re_ts, re_pos) - self.assertEqual([f'INFO:{logname}:Wheel in rad units using X2 encoding'], cm.output) - - # Check the first 3 intervals. As position thresholds are adjusted by units and - # encoding, we should expect the intervals to be identical to above - actual = wheel_moves['intervals'][:3, ] - self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') - - def test_movement_log(self): - """ - Integration test for inferring the units and decoding type for wheel data input for - extract_wheel_moves. Only expected to work for the default wheel diameter. - """ - ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) - pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - logger = logging.getLogger(logname := 'ibllib.io.extractors.training_wheel') - - for unit in ['cm', 'rad']: - for i in (1, 2, 4): - encoding = 'X' + str(i) - r = 3.1 if unit == 'cm' else 1 - # print(encoding, unit) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - ta, pa, tb, pb, ticks=1024, coding=encoding.lower(), radius=r) - expected = f'INFO:{logname}:Wheel in {unit} units using {encoding} encoding' - with self.assertLogs(logger, level='INFO') as cm: - ephys_fpga.extract_wheel_moves(t, p) - self.assertEqual([expected], cm.output) - - def test_extract_first_movement_times(self): - test_data = self.test_data[1] - wheel_moves = ephys_fpga.extract_wheel_moves(test_data[0][0], test_data[0][1]) - first, is_final, ind = extract_first_movement_times(wheel_moves, self.trials) - np.testing.assert_allclose(first, [162.48462599, 105.62562599, np.nan]) - np.testing.assert_array_equal(is_final, [False, True, False]) - np.testing.assert_array_equal(ind, [46, 18]) - - class TestEphysFPGA_TTLsExtraction(unittest.TestCase): def test_audio_ttl_wiring_camera(self): """ + Test ephys_fpga._clean_audio function. + Test removal of spurious TTLs due to a wrong wiring of the camera onto the soundcard example eid: e349a2e7-50a3-47ca-bc45-20d1899854ec """ diff --git a/ibllib/tests/extractors/test_ephys_trials.py b/ibllib/tests/extractors/test_ephys_trials.py index d5483792f..7d77079af 100644 --- a/ibllib/tests/extractors/test_ephys_trials.py +++ b/ibllib/tests/extractors/test_ephys_trials.py @@ -1,15 +1,23 @@ import unittest from pathlib import Path +import pickle + import numpy as np from ibllib.io.extractors import ephys_fpga, biased_trials import ibllib.io.raw_data_loaders as raw +from ibllib.io.extractors.training_wheel import extract_first_movement_times, infer_wheel_units +from ibllib.io.extractors.training_wheel import extract_wheel_moves +import brainbox.behavior.wheel as wh class TestEphysSyncExtraction(unittest.TestCase): def test_bpod_trace_extraction(self): + """Test ephys_fpga._assign_events_bpod function. + TODO Remove this test and corresponding function. + """ t_valve_open_ = np.array([117.12136667, 122.3873, 127.82903333, 140.56083333, 143.55326667, 155.29713333, 164.9186, 167.91133333, 171.39736667, 178.0305, 181.70343333]) @@ -48,6 +56,7 @@ def test_bpod_trace_extraction(self): self.assertTrue(np.all(np.isclose(t_valve_open, t_valve_open_))) def test_align_to_trial(self): + """Test ephys_fpga._assign_events_to_trial function.""" # simple test with one missing at the end t_trial_start = np.arange(0, 5) * 10 t_event = np.arange(0, 5) * 10 + 2 @@ -95,6 +104,7 @@ def test_align_to_trial(self): self.assertRaises(ValueError, ephys_fpga._assign_events_to_trial, t_trial_start, np.array([0., 2., 1.])) def test_wheel_trace_from_sync(self): + """Test ephys_fpga._rotary_encoder_positions_from_fronts function.""" pos_ = - np.array([-1, 0, -1, -2, -1, -2]) * (np.pi / ephys_fpga.WHEEL_TICKS) ta = np.array([1, 2, 3, 4, 5, 6]) tb = np.array([0.5, 3.2, 3.3, 3.4, 5.25, 5.5]) @@ -137,5 +147,184 @@ def test_get_probabilityLeft(self): self.assertTrue(all([x in [0.2, 0.5, 0.8] for x in np.unique(pLeft1)])) +class TestWheelExtraction(unittest.TestCase): + + def setUp(self) -> None: + self.ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + self.pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + self.tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) + self.pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + + def test_x1_decoding(self): + p_ = np.array([1, 2, 1, 0]) + t_ = np.array([2, 6, 11, 15]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x1') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(p == p_)) + + def test_x4_decoding(self): + p_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 0]) / 4 + t_ = np.array([2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x4') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(np.isclose(p, p_))) + + def test_x2_decoding(self): + p_ = np.array([1, 2, 3, 4, 3, 2, 1, 0]) / 2 + t_ = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x2') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(p == p_)) + + +class TestExtractedWheelUnits(unittest.TestCase): + """Tests the infer_wheel_units function""" + + wheel_radius_cm = 3.1 + + def setUp(self) -> None: + """ + Create the wheel position data for testing: the positions attribute holds a dictionary of + units, each holding a dictionary of encoding types to test, e.g. + + positions = { + 'rad': { + 'X1': ..., + 'X2': ..., + 'X4': ... + }, + 'cm': { + 'X1': ..., + 'X2': ..., + 'X4': ... + } + } + :return: + """ + def x(unit, enc=int(1), wheel_radius=self.wheel_radius_cm): + radius = 1 if unit == 'rad' else wheel_radius + return 1 / ephys_fpga.WHEEL_TICKS * np.pi * 2 * radius / enc + + # A pseudo-random sequence of integrated fronts + seq = np.array([-1, 0, 1, 2, 1, 2, 3, 4, 3, 2, 1, 0, -1, -2, 1, -2]) + encs = (1, 2, 4) # Encoding types to test + units = ('rad', 'cm') # Units to test + self.positions = {unit: {f'X{e}': x(unit, e) * seq for e in encs} for unit in units} + + def test_extract_wheel_moves(self): + for unit in self.positions.keys(): + for encoding, pos in self.positions[unit].items(): + result = infer_wheel_units(pos) + self.assertEqual(unit, result[0], f'failed to determine units for {encoding}') + expected = int(ephys_fpga.WHEEL_TICKS * int(encoding[1])) + self.assertEqual(expected, result[1], + f'failed to determine number of ticks for {encoding} in {unit}') + self.assertEqual(encoding, result[2], f'failed to determine encoding in {unit}') + + +class TestWheelMovesExtraction(unittest.TestCase): + + def setUp(self) -> None: + """ + Test data is in the form ((inputs), (outputs)) where inputs is a tuple containing a + numpy array of timestamps and one of positions; outputs is a tuple of outputs from + the functions. For details, see help on TestWheel.setUp method in module + brainbox.tests.test_behavior + """ + pickle_file = Path(__file__).parents[3].joinpath( + 'brainbox', 'tests', 'fixtures', 'wheel_test.p') + if not pickle_file.exists(): + self.test_data = None + else: + with open(pickle_file, 'rb') as f: + self.test_data = pickle.load(f) + + # Some trial times for trial_data[1] + self.trials = { + 'goCue_times': np.array([162.5, 105.6, 55]), + 'feedback_times': np.array([164.3, 108.3, 56]) + } + + def test_extract_wheel_moves(self): + test_data = self.test_data[1] + # Wrangle data into expected form + re_ts = test_data[0][0] + re_pos = test_data[0][1] + + logname = 'ibllib.io.extractors.training_wheel' + with self.assertLogs(logname, level='INFO') as cm: + wheel_moves = extract_wheel_moves(re_ts, re_pos) + self.assertEqual([f'INFO:{logname}:Wheel in cm units using X2 encoding'], cm.output) + + n = 56 # expected number of movements + self.assertTupleEqual(wheel_moves['intervals'].shape, (n, 2), + 'failed to return the correct number of intervals') + self.assertEqual(wheel_moves['peakAmplitude'].size, n) + self.assertEqual(wheel_moves['peakVelocity_times'].size, n) + + # Check the first 3 intervals + ints = np.array( + [[24.78462599, 25.22562599], + [29.58762599, 31.15062599], + [31.64262599, 31.81662599]]) + actual = wheel_moves['intervals'][:3, ] + self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') + + # Check amplitudes + actual = wheel_moves['peakAmplitude'][-3:] + expected = [0.50255486, -1.70103154, 1.00740789] + self.assertIsNone(np.testing.assert_allclose(actual, expected), 'unexpected amplitudes') + + # Check peak velocities + actual = wheel_moves['peakVelocity_times'][-3:] + expected = [175.13662599, 176.65762599, 178.57262599] + self.assertIsNone(np.testing.assert_allclose(actual, expected), 'peak times') + + # Test extraction in rad + re_pos = wh.cm_to_rad(re_pos) + with self.assertLogs(logname, level='INFO') as cm: + wheel_moves = ephys_fpga.extract_wheel_moves(re_ts, re_pos) + self.assertEqual([f'INFO:{logname}:Wheel in rad units using X2 encoding'], cm.output) + + # Check the first 3 intervals. As position thresholds are adjusted by units and + # encoding, we should expect the intervals to be identical to above + actual = wheel_moves['intervals'][:3, ] + self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') + + def test_movement_log(self): + """ + Integration test for inferring the units and decoding type for wheel data input for + extract_wheel_moves. Only expected to work for the default wheel diameter. + """ + ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) + pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + logname = 'ibllib.io.extractors.training_wheel' + + for unit in ['cm', 'rad']: + for i in (1, 2, 4): + encoding = 'X' + str(i) + r = 3.1 if unit == 'cm' else 1 + # print(encoding, unit) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + ta, pa, tb, pb, ticks=1024, coding=encoding.lower(), radius=r) + expected = f'INFO:{logname}:Wheel in {unit} units using {encoding} encoding' + with self.assertLogs(logname, level='INFO') as cm: + ephys_fpga.extract_wheel_moves(t, p) + self.assertEqual([expected], cm.output) + + def test_extract_first_movement_times(self): + test_data = self.test_data[1] + wheel_moves = ephys_fpga.extract_wheel_moves(test_data[0][0], test_data[0][1]) + first, is_final, ind = extract_first_movement_times(wheel_moves, self.trials) + np.testing.assert_allclose(first, [162.48462599, 105.62562599, np.nan]) + np.testing.assert_array_equal(is_final, [False, True, False]) + np.testing.assert_array_equal(ind, [46, 18]) + + if __name__ == '__main__': unittest.main(exit=False, verbosity=2) From 822f1837995cc208100dfd35d1191336037834ea Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 28 Nov 2023 12:58:26 +0000 Subject: [PATCH 46/68] Auto stash before merge of "develop" and "origin/develop" support changes due to newer matplotlib version --- brainbox/plot_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainbox/plot_base.py b/brainbox/plot_base.py index cbba05aff..e6515287b 100644 --- a/brainbox/plot_base.py +++ b/brainbox/plot_base.py @@ -568,7 +568,7 @@ def plot_probe(data, ax=None, show_cbar=True, make_pretty=True, fig_kwargs=dict( im = NonUniformImage(ax, interpolation='nearest', cmap=data['cmap']) im.set_clim(data['clim'][0], data['clim'][1]) im.set_data(x, y, dat.T) - ax.images.append(im) + ax.add_image(im) ax.set_xlim(data['xlim'][0], data['xlim'][1]) ax.set_ylim(data['ylim'][0], data['ylim'][1]) From dd7b0873e48011cefb0b4bdb0aad35ef21429e36 Mon Sep 17 00:00:00 2001 From: juhuntenburg Date: Tue, 28 Nov 2023 21:08:32 +0100 Subject: [PATCH 47/68] training status to run on SDSC --- ibllib/pipes/training_status.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index fc73304c6..6ef4bf120 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -165,9 +165,9 @@ def load_trials(sess_path, one, collections=None, force=True, mode='raise'): try: # try and load all trials that are found locally in the session path locally if collections is None: - trial_locations = list(sess_path.rglob('_ibl_trials.goCueTrigger_times.npy')) + trial_locations = list(sess_path.rglob('_ibl_trials.goCueTrigger_times.*npy')) else: - trial_locations = [Path(sess_path).joinpath(c, '_ibl_trials.goCueTrigger_times.npy') for c in collections] + trial_locations = [Path(sess_path).joinpath(c, '_ibl_trials.goCueTrigger_times.*npy') for c in collections] if len(trial_locations) > 1: trial_dict = {} From def135371c871f7d1e97fb2a90f6c97ddb0b8d25 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Thu, 30 Nov 2023 13:36:52 +0000 Subject: [PATCH 48/68] training status - only extract bpod data, don't run qc and don't sync --- ibllib/pipes/training_status.py | 105 +++++++++++++------------------- 1 file changed, 42 insertions(+), 63 deletions(-) diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py index fc73304c6..bbe839d16 100644 --- a/ibllib/pipes/training_status.py +++ b/ibllib/pipes/training_status.py @@ -3,9 +3,9 @@ from ibllib.io.raw_data_loaders import load_bpod from ibllib.oneibl.registration import _get_session_times -from ibllib.io.extractors.base import get_pipeline, get_session_extractor_type +from ibllib.io.extractors.base import get_session_extractor_type from ibllib.io.session_params import read_params -import ibllib.pipes.dynamic_pipeline as dyn +from ibllib.io.extractors.bpod_trials import get_bpod_extractor from iblutil.util import setup_logger from ibllib.plots.snapshot import ReportSnapshot @@ -22,6 +22,7 @@ import seaborn as sns import boto3 from botocore.exceptions import ProfileNotFound, ClientError +from itertools import chain logger = setup_logger(__name__) @@ -87,43 +88,6 @@ def upload_training_table_to_aws(lab, subject): return -def get_trials_task(session_path, one): - # If experiment description file then process this - experiment_description_file = read_params(session_path) - if experiment_description_file is not None: - tasks = [] - pipeline = dyn.make_pipeline(session_path) - trials_tasks = [t for t in pipeline.tasks if 'Trials' in t] - for task in trials_tasks: - t = pipeline.tasks.get(task) - t.__init__(session_path, **t.kwargs) - tasks.append(t) - else: - # Otherwise default to old way of doing things - pipeline = get_pipeline(session_path) - if pipeline == 'training': - from ibllib.pipes.training_preprocessing import TrainingTrials - tasks = [TrainingTrials(session_path)] - elif pipeline == 'ephys': - from ibllib.pipes.ephys_preprocessing import EphysTrials - tasks = [EphysTrials(session_path)] - else: - try: - # try and look if there is a custom extractor in the personal projects extraction class - import projects.base - task_type = get_session_extractor_type(session_path) - PipelineClass = projects.base.get_pipeline(task_type) - pipeline = PipelineClass(session_path, one) - trials_task_name = next(task for task in pipeline.tasks if 'Trials' in task) - task = pipeline.tasks.get(trials_task_name) - task.__init__(session_path) - tasks = [task] - except Exception: - tasks = [] - - return tasks - - def save_path(subj_path): return Path(subj_path).joinpath('training.csv') @@ -155,7 +119,7 @@ def load_existing_dataframe(subj_path): def load_trials(sess_path, one, collections=None, force=True, mode='raise'): """ Load trials data for session. First attempts to load from local session path, if this fails will attempt to download via ONE, - if this also fails, will then attempt to re-extraxt locally + if this also fails, will then attempt to re-extract locally :param sess_path: session path :param one: ONE instance :param force: when True and if the session trials can't be found, will attempt to re-extract from the disk @@ -207,19 +171,24 @@ def load_trials(sess_path, one, collections=None, force=True, mode='raise'): if 'probabilityLeft' not in trials.keys(): raise ALFObjectNotFound except Exception: - # Finally try to rextract the trials data locally + # Finally try to re-extract the trials data locally try: - # Get the tasks that need to be run - tasks = get_trials_task(sess_path, one) - if len(tasks) > 0: - for task in tasks: - status = task.run() - if status == 0: - return load_trials(sess_path, collections=collections, one=one, force=False) - else: - return + raw_collections, _ = get_data_collection(sess_path) + + if len(raw_collections) == 0: + return None + + trials_dict = {} + for i, collection in enumerate(raw_collections): + extractor = get_bpod_extractor(sess_path, task_collection=collection) + trials_data, _ = extractor.extract(task_collection=collection, save=False) + trials_dict[i] = alfio.AlfBunch.from_df(trials_data['table']) + + if len(trials_dict) > 1: + trials = training.concatenate_trials(trials_dict) else: - trials = None + trials = trials_dict[0] + except Exception as e: if mode == 'raise': raise Exception(f'Exhausted all possibilities for loading trials for {sess_path}') from e @@ -468,20 +437,29 @@ def get_data_collection(session_path): :param session_path: path of session :return: """ - experiment_description_file = read_params(session_path) - if experiment_description_file is not None: - pipeline = dyn.make_pipeline(session_path) - trials_tasks = [t for t in pipeline.tasks if 'Trials' in t] - collections = [pipeline.tasks.get(task).kwargs['collection'] for task in trials_tasks] - if len(collections) == 1 and collections[0] == 'raw_behavior_data': - alf_collections = ['alf'] - elif all(['raw_task_data' in c for c in collections]): - alf_collections = [f'alf/task_{c[-2:]}' for c in collections] - else: - alf_collections = None + experiment_description = read_params(session_path) + collections = [] + if experiment_description is not None: + task_protocols = experiment_description.get('tasks', []) + for i, (protocol, task_info) in enumerate(chain(*map(dict.items, task_protocols))): + if 'passiveChoiceWorld' in protocol: + continue + collection = task_info.get('collection', f'raw_task_data_{i:02}') + if collection == 'raw_passive_data': + continue + collections.append(collection) else: - collections = ['raw_behavior_data'] + settings = Path(session_path).rglob('_iblrig_taskSettings.raw.json') + for setting in settings: + if setting.parent.name != 'raw_passive_data': + collections.append(setting.parent.name) + + if len(collections) == 1 and collections[0] == 'raw_behavior_data': alf_collections = ['alf'] + elif all(['raw_task_data' in c for c in collections]): + alf_collections = [f'alf/task_{c[-2:]}' for c in collections] + else: + alf_collections = None return collections, alf_collections @@ -561,6 +539,7 @@ def get_training_info_for_session(session_paths, one, force=True): un_protocols = np.unique(protocols) # Example, training, training, biased - training would be combined, biased not + sess_dict = None if len(un_protocols) != 1: print(f'Different protocols in same session {session_path} : {protocols}') for prot in un_protocols: From d2294982f5c7462f16fee375c4d44ea9767fc9f1 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 5 Dec 2023 14:54:49 +0200 Subject: [PATCH 49/68] Remove test reference to module constant --- ibllib/tests/extractors/test_ephys_fpga.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ibllib/tests/extractors/test_ephys_fpga.py b/ibllib/tests/extractors/test_ephys_fpga.py index 465322810..fdfe27218 100644 --- a/ibllib/tests/extractors/test_ephys_fpga.py +++ b/ibllib/tests/extractors/test_ephys_fpga.py @@ -208,14 +208,15 @@ def test_frame2ttl_flickers(self): switches under a given threshold """ DISPLAY = False # for debug purposes - diff = ephys_fpga.F2TTL_THRESH * np.array([0.5, 10]) + F2TTL_THRESH = 0.01 + diff = F2TTL_THRESH * np.array([0.5, 10]) # flicker ends with a polarity switch - downgoing pulse is removed t = np.r_[0, np.cumsum(diff[np.array([1, 1, 0, 0, 1])])] + 1 frame2ttl = {'times': t, 'polarities': np.mod(np.arange(t.size) + 1, 2) * 2 - 1} expected = {'times': np.array([1., 1.1, 1.2, 1.31]), 'polarities': np.array([1, -1, 1, -1])} - frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY) + frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY, threshold=F2TTL_THRESH) assert all([np.all(frame2ttl_[k] == expected[k]) for k in frame2ttl_]) # stand-alone flicker From cb9bc9df7d306b4568970389eb51cc170e5edbba Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 6 Dec 2023 09:36:07 +0000 Subject: [PATCH 50/68] when no spacer take last ttl as end time --- ibllib/io/extractors/ephys_passive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/io/extractors/ephys_passive.py b/ibllib/io/extractors/ephys_passive.py index 2dfcb34e2..78e207ad8 100644 --- a/ibllib/io/extractors/ephys_passive.py +++ b/ibllib/io/extractors/ephys_passive.py @@ -227,8 +227,8 @@ def _get_passive_spacers(session_path, sync_collection='raw_ephys_data', f'trace ({int(np.size(spacer_times) / 2)})' ) - if tmax is None: # TODO THIS NEEDS CHANGING AS FOR DYNAMIC PIPELINE F2TTL slower than valve - tmax = fttl['times'][-1] + if tmax is None: + tmax = sync['times'][-1] spacer_times = np.r_[spacer_times.flatten(), tmax] return spacer_times[0], spacer_times[1::2], spacer_times[2::2] From 3a02da274db7c3d28f431dd1a50c51ff585e60bb Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 6 Dec 2023 12:35:48 +0000 Subject: [PATCH 51/68] improve logging --- ibllib/io/extractors/ephys_passive.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ibllib/io/extractors/ephys_passive.py b/ibllib/io/extractors/ephys_passive.py index 78e207ad8..0867a3f67 100644 --- a/ibllib/io/extractors/ephys_passive.py +++ b/ibllib/io/extractors/ephys_passive.py @@ -418,8 +418,10 @@ def _extract_passiveAudio_intervals(audio: dict, rig_version: str) -> Tuple[np.a soundOff_times = audio["times"][audio["polarities"] < 0] # Check they are the correct number - assert len(soundOn_times) == NTONES + NNOISES, "Wrong number of sound ONSETS" - assert len(soundOff_times) == NTONES + NNOISES, "Wrong number of sound OFFSETS" + assert len(soundOn_times) == NTONES + NNOISES, f"Wrong number of sound ONSETS, " \ + f"{len(soundOn_times)}/{NTONES + NNOISES}" + assert len(soundOff_times) == NTONES + NNOISES, f"Wrong number of sound OFFSETS, " \ + f"{len(soundOn_times)}/{NTONES + NNOISES}" diff = soundOff_times - soundOn_times # Tone is ~100ms so check if diff < 0.3 From 8aff0ad6875fba0367545cb8673400722cc29830 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 8 Dec 2023 13:01:35 +0200 Subject: [PATCH 52/68] mesoscope trials extractor refactor; fix attribute_times numpy version bug; FpgaTrials.build_trials method call after sync --- ibllib/io/extractors/camera.py | 6 +- ibllib/io/extractors/ephys_fpga.py | 65 +++-- ibllib/io/extractors/mesoscope.py | 373 +++++++++++++++++++++-------- ibllib/io/raw_daq_loaders.py | 2 +- 4 files changed, 320 insertions(+), 126 deletions(-) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index 7612c3e9e..93554c86a 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -513,12 +513,16 @@ def attribute_times(arr, events, tol=.1, injective=True, take='first'): Returns ------- numpy.array - An array the same length as `events`. + An array the same length as `events` containing indices of `arr` corresponding to each + event. """ if (take := take.lower()) not in ('first', 'nearest', 'after'): raise ValueError('Parameter `take` must be either "first", "nearest", or "after"') stack = np.ma.masked_invalid(arr, copy=False) stack.fill_value = np.inf + # If there are no invalid values, the mask is False so let's ensure it's a bool array + if stack.mask is np.bool_(0): + stack.mask = np.zeros(arr.shape, dtype=bool) assigned = np.full(events.shape, -1, dtype=int) # Initialize output array min_tol = 0 if take == 'after' else -tol for i, x in enumerate(events): diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index ad2cb0ab5..aa042ce8e 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -900,8 +900,8 @@ def load_sync(self, sync_collection='raw_ephys_data', **kwargs): Returns ------- one.alf.io.AlfBunch - A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and - the corresponding channel numbers. + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. dict A map of channel names and their corresponding indices. """ @@ -992,24 +992,9 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', # Get the trial events from the DAQ sync TTLs fpga_trials = self.extract_behaviour_sync(sync, chmap, **kwargs) - # Sync the Bpod clock to the DAQ - self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + # Sync clocks and build final trials datasets + out = self.build_trials(fpga_trials, sync=sync, chmap=chmap, **kwargs) - if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': - # One issue is that sometimes pulses may not have been detected, in this case - # add the events that have not been detected and re-extract the behaviour sync. - # This is only really relevant for the Bpod interval events as the other TTLs are - # from devices where a missing TTL likely means the Bpod event was truly absent. - _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') - bpod_start = self.bpod_trials['intervals'][:, 0] - missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) - t_trial_start = np.sort(np.r_[fpga_trials['intervals'][:, 0], missing_bpod]) - fpga_trials = self.extract_behaviour_sync(sync, chmap, start_times=t_trial_start, **kwargs) - - out = dict() - out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) - out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) - out.update({k: fpga_trials[k][ifpga] for k in sorted(fpga_trials.keys())}) # extract the wheel data if any(x.startswith('wheel') for x in self.var_names): wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) @@ -1096,9 +1081,16 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * ] bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] + t_trial_start = bpod_event_intervals['trial_start'][:, 0] t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T - # Drop last trial start if incomplete - t_trial_start = bpod_event_intervals['trial_start'][:len(t_trial_end), 0] + # Some protocols, e.g. Guido's ephys biased opto task, have no trial end TTL. + # This is not essential as the trial start is used to sync the clocks. + if t_trial_end.size == 0: + _logger.warning('No trial end / ITI in TTLs found') + t_trial_end = np.full_like(t_trial_start, np.nan) + else: + # Drop last trial start if incomplete + t_trial_start = t_trial_start[:len(t_trial_end)] t_valve_open = bpod_event_intervals['valve_open'][:, 0] t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] t_error_tone_in = audio_event_intervals['error_tone'][:, 0] @@ -1136,12 +1128,33 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) ax.legend() - ax.set_yticks([0, 1, 2]) - ax.set_yticklabels(['bpod', 'f2ttl', 'audio']) + ax.set_yticks([0, 1, 2, 3]) + ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) ax.set_ylim([0, 5]) return trials + def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): + # Sync the Bpod clock to the DAQ + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + + if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': + # One issue is that sometimes pulses may not have been detected, in this case + # add the events that have not been detected and re-extract the behaviour sync. + # This is only really relevant for the Bpod interval events as the other TTLs are + # from devices where a missing TTL likely means the Bpod event was truly absent. + _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') + bpod_start = self.bpod_trials['intervals'][:, 0] + missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) + t_trial_start = np.sort(np.r_[fpga_trials['intervals'][:, 0], missing_bpod]) + fpga_trials = self.extract_behaviour_sync(sync, chmap, start_times=t_trial_start, **kwargs) + + out = dict() + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + out.update({k: fpga_trials[k][ifpga] for k in sorted(fpga_trials.keys())}) + return out + def get_wheel_positions(self, *args, **kwargs): """Extract wheel and wheelMoves objects. @@ -1569,9 +1582,9 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) ax.legend() - ax.set_yticks([0, 1, 2]) - ax.set_yticklabels(['bpod', 'f2ttl', 'audio']) - ax.set_ylim([0, 5]) + ax.set_yticks([0, 1, 2, 3]) + ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 4]) return trials diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 4def5ed3a..84a7622e7 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -1,22 +1,24 @@ """Mesoscope (timeline) data extraction.""" import logging +from itertools import cycle import numpy as np +from scipy.signal import find_peaks import one.alf.io as alfio from one.util import ensure_list from one.alf.files import session_path_parts import matplotlib.pyplot as plt -from neurodsp.utils import falls +from matplotlib.colors import TABLEAU_COLORS from pkg_resources import parse_version from ibllib.plots.misc import squares, vertical_lines from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel, correct_counter_discontinuities, load_timeline_sync_and_chmap) import ibllib.io.extractors.base as extractors_base -from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, get_sync_fronts, get_protocol_period +from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial from ibllib.io.extractors.training_wheel import extract_wheel_moves from ibllib.io.extractors.camera import attribute_times -from ibllib.io.extractors.ephys_fpga import _assign_events_bpod +from brainbox.behavior.wheel import velocity_filtered _logger = logging.getLogger(__name__) @@ -102,103 +104,240 @@ def plot_timeline(timeline, channels=None, raw=True): class TimelineTrials(FpgaTrials): """Similar extraction to the FPGA, however counter and position channels are treated differently.""" - """one.alf.io.AlfBunch: The timeline data object""" timeline = None + """one.alf.io.AlfBunch: The timeline data object.""" + + sync_field = 'itiIn_times' # trial start events + """str: The trial event to synchronize (must be present in extracted trials).""" def __init__(self, *args, sync_collection='raw_sync_data', **kwargs): """An extractor for all ephys trial data, in Timeline time""" super().__init__(*args, **kwargs) self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') - def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: - if not (sync or chmap): - sync, chmap = load_timeline_sync_and_chmap( - self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) + def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_): + """Load the DAQ sync and channel map data. + + Parameters + ---------- + sync_collection : str + The session subdirectory where the sync data are located. + chmap : dict + A map of channel names and their corresponding indices. If None, the channel map is + loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method. + + Returns + ------- + one.alf.io.AlfBunch + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + dict + A map of channel names and their corresponding indices. + """ + if not self.timeline: + self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') + sync, chmap = load_timeline_sync_and_chmap( + self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) + return sync, chmap + def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: + trials = super()._extract(sync, chmap, sync_collection='raw_sync_data', **kwargs) if kwargs.get('display', False): plot_timeline(self.timeline, channels=chmap.keys(), raw=True) - trials = super()._extract(sync, chmap, sync_collection, extractor_type='ephys', **kwargs) - - # If no protocol number is defined, trim timestamps based on Bpod trials intervals - trials_table = trials['table'] - bpod = get_sync_fronts(sync, chmap['bpod']) - if kwargs.get('protocol_number') is None: - tmin = trials_table.intervals_0.iloc[0] - 1 - tmax = trials_table.intervals_1.iloc[-1] - # Ensure wheel is cut off based on trials - mask = np.logical_and(tmin <= trials['wheel_timestamps'], trials['wheel_timestamps'] <= tmax) - trials['wheel_timestamps'] = trials['wheel_timestamps'][mask] - trials['wheel_position'] = trials['wheel_position'][mask] - mask = np.logical_and(trials['wheelMoves_intervals'][:, 0] >= tmin, trials['wheelMoves_intervals'][:, 0] <= tmax) - trials['wheelMoves_intervals'] = trials['wheelMoves_intervals'][mask, :] + return trials + + def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + """ + Extract task related event times from the sync. + + TODO Change docstring + The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The + first trial start TTL of the session is longer and must be handled differently. The trial + start TTL is used to assign the other trial events to each trial. + + The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest + of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio + tones. The first of these after each trial start is taken to be the go cue time. Error + tones are longer audio TTLs and assigned as the last of such occurrence after each trial + start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial. + The feedback times are times of either valve open or error tone as there should be only one + such event per trial. + + The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs + removed): the first TTL after each trial start is assumed to be the stim onset time; the + second to last and last are taken as the stimulus freeze and offset times, respectively. + + Parameters + ---------- + sync : dict + 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' + chmap : dict + Map of channel names and their corresponding index. Default to constant. + start_times : numpy.array + An optional array of timestamps to separate trial events by. This is useful if after + syncing the clocks, some trial start TTLs are found to be missed. If None, uses + 'trial_start' Bpod event. + display : bool, matplotlib.pyplot.Axes + Show the full session sync pulses display. + + Returns + ------- + dict + A map of trial event timestamps. + """ + # Get the events from the sync. + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) + self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) + if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}: + raise ValueError( + 'Expected at least "ready_tone" and "error_tone" audio events.' + '`audio_event_ttls` kwarg may be incorrect.') + + self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) + if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}: + raise ValueError( + 'Expected at least "trial_end" and "valve_open" audio events. ' + '`bpod_event_ttls` kwarg may be incorrect.') + + t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T + trials = alfio.AlfBunch({ + 'itiIn_times': t_iti_in, + 'intervals_1': t_trial_end, + 'valveOpen_intervals': bpod_event_intervals['valve_open'], + 'goCue_times': audio_event_intervals['ready_tone'][:, 0], + 'errorTone_times': audio_event_intervals['error_tone'][:, 0] + }) + + if display: # pragma: no cover + width = 0.5 + ymax = 5 + if isinstance(display, bool): + plt.figure('Bpod FPGA Sync') + ax = plt.gca() + else: + ax = display + squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') + squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') + squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') + color_map = TABLEAU_COLORS.keys() + for (event_name, event_times), c in zip(trials.items(), cycle(color_map)): + vertical_lines(event_times.flat, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) + ax.legend() + ax.set_yticks([0, 1, 2, 3]) + ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 4]) + + return trials + + def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): + # Sync the Bpod clock to the DAQ + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + + out = dict() + out['intervals'] = self.bpod2fpga(self.bpod_trials['intervals']) + out['itiIn_times'] = fpga_trials['itiIn_times'][ifpga] + start_times = out['intervals'][:, 0] + + # Extract valve open times from the DAQ + valve_driver_ttls = fpga_trials.pop('valveOpen_intervals') + correct = self.bpod_trials['feedbackType'] == 1 + # If there is a reward_valve channel, the valve has + if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']): + # TODO Let's look at the expected open length based on calibration and reward volume + # import scipy.interpolate + # # FIXME support v7 settings? + # fcn_vol2time = scipy.interpolate.pchip( + # self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_WEIGHT_PERDROP'], + # self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_OPEN_TIMES'] + # ) + # reward_time = fcn_vol2time(self.bpod_extractor.settings.get('REWARD_AMOUNT_UL')) / 1e3 + + # Use the driver TTLs to find the valve open times that correspond to the valve opening + valve_intervals, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls) + if valve_open_times.size != np.sum(correct): + _logger.warning( + 'Number of valve open times does not equal number of correct trials (%i != %i)', + valve_open_times.size, np.sum(correct)) + + out['valveOpen_times'] = _assign_events_to_trial(start_times, valve_open_times) else: - tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod) - bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax) - - self.frame2ttl = get_sync_fronts(sync, chmap['frame2ttl'], tmin, tmax) # save for later access by QC - - # Replace valve open times with those extracted from the DAQ - # TODO Let's look at the expected open length based on calibration and reward volume - assert len(bpod['times']) > 0, 'No Bpod TTLs detected on DAQ' - _, driver_out, _, = _assign_events_bpod(bpod['times'], bpod['polarities'], False) - # Use the driver TTLs to find the valve open times that correspond to the valve opening - valve_open_times = self.get_valve_open_times(driver_ttls=driver_out) - assert len(valve_open_times) == sum(trials_table.feedbackType == 1) # TODO Relax assertion - correct = trials_table.feedbackType == 1 - trials['valveOpen_times'][correct] = valve_open_times - trials_table.feedback_times[correct] = valve_open_times - - # Replace audio events - self.audio = get_sync_fronts(sync, chmap['audio'], tmin, tmax) - # Attempt to assign the go cue and error tone onsets based on TTL length - go_cue, error_cue = self._assign_events_audio(self.audio['times'], self.audio['polarities']) - - assert error_cue.size == np.sum(~correct), 'N detected error tones does not match number of incorrect trials' - assert go_cue.size <= len(trials_table), 'More go cue tones detected than trials!' - - if go_cue.size < len(trials_table): - _logger.warning('%i go cue tones missed', len(trials_table) - go_cue.size) + # Use the valve controller TTLs recorded on the Bpod channel as the reward time + out['valveOpen_times'] = _assign_events_to_trial(start_times, valve_driver_ttls[:, 0]) + + # Stimulus times extracted the same as usual + out['stimFreeze_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2) + out['stimOn_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first') + out['stimOff_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times']) + + # Audio times + error_cue = fpga_trials['errorTone_times'] + if error_cue.size != np.sum(~correct): + _logger.warning( + 'N detected error tones does not match number of incorrect trials (%i != %i)', + error_cue.size, np.sum(~correct)) + go_cue = fpga_trials['goCue_times'] + out['goCue_times'] = _assign_events_to_trial(start_times, go_cue, take='first') + out['errorCue_times'] = _assign_events_to_trial(start_times, error_cue) + + if go_cue.size > start_times.size: + _logger.warning( + 'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size) + elif go_cue.size < start_times.size: """ If the error cues are all assigned and some go cues are missed it may be that some - responses were so fast that the go cue and error tone merged. + responses were so fast that the go cue and error tone merged, or the go cue TTL was too + long. """ + _logger.warning('%i go cue tones missed', start_times.size - go_cue.size) err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times']) go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times']) assert not np.any(np.isnan(go_trig)) - assert err_trig.size == go_trig.size - - def first_true(arr): - """Return the index of the first True value in an array.""" - indices = np.where(arr)[0] - return None if len(indices) == 0 else indices[0] + assert err_trig.size == go_trig.size # should be length of n trials with NaNs # Find which trials are missing a go cue - _go_cue = np.full(len(trials_table), np.nan) - for i, intervals in enumerate(trials_table[['intervals_0', 'intervals_1']].values): - idx = first_true(np.logical_and(go_cue > intervals[0], go_cue < intervals[1])) - if idx is not None: - _go_cue[i] = go_cue[idx] + _go_cue = _assign_events_to_trial(start_times, go_cue, take='first') + error_cue = _assign_events_to_trial(start_times, error_cue) + missing = np.isnan(_go_cue) # Get all the DAQ timestamps where audio channel was HIGH raw = timeline_get_channel(self.timeline, 'audio') raw = (raw - raw.min()) / (raw.max() - raw.min()) # min-max normalize ups = self.timeline.timestamps[raw > .5] # timestamps where input HIGH - for i in np.where(np.isnan(_go_cue))[0]: - # Get the timestamp of the first HIGH after the trigger times - _go_cue[i] = ups[first_true(ups > go_trig[i])] - idx = first_true(np.logical_and( - error_cue > trials_table['intervals_0'][i], - error_cue < trials_table['intervals_1'][i])) - if np.isnan(err_trig[i]): - if idx is not None: - error_cue = np.delete(error_cue, idx) # Remove mis-assigned error tone time - else: - error_cue[idx] = ups[first_true(ups > err_trig[i])] - go_cue = _go_cue - - trials_table.feedback_times[~correct] = error_cue - trials_table.goCue_times = go_cue - return {k: trials[k] for k in self.var_names} + + # Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after). + # Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN) + idx = attribute_times(ups, go_trig, tol=0.2, take='after') + # Trial indices that didn't have detected goCue and now has been assigned an `ups` index + assigned = np.where(idx != -1 & missing)[0] # ignore unassigned + _go_cue[assigned] = ups[idx[assigned]] + + # Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue) + error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig)) + i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True) + error_cue[i_to_remove] = np.nan + + # For those trials where go cue was merged with the error cue and therefore mis-assigned, + # we must re-assign the error cue times as the first HIGH after the error trigger. + idx = attribute_times(ups, err_trig, tol=0.2, take='after') + assigned = np.where(idx != -1 & missing)[0] # ignore unassigned + error_cue[assigned] = ups[idx[assigned]] + out['goCue_times'] = _go_cue + out['errorCue_times'] = error_cue + + # Because we're not + assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \ + 'audio tones not assigned correctly; tones likely missed' + + # Feedback times + out['feedback_times'] = np.copy(out['valveOpen_times']) + ind_err = np.isnan(out['valveOpen_times']) + out['feedback_times'][ind_err] = out['errorCue_times'][ind_err] + + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + + return out def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None): """ @@ -234,7 +373,7 @@ def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding=' # Timeline evenly samples counter so we extract only change points d = np.diff(raw) - ind, = np.where(d.astype(int)) + ind, = np.where(~np.isclose(d, 0)) pos = raw[ind + 1] pos -= pos[0] # Start from zero pos = pos / ticks * np.pi * 2 * radius / int(coding[1]) # Convert to radians @@ -290,7 +429,7 @@ def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding= ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s') return wheel, moves - def get_valve_open_times(self, display=False, threshold=-2.5, floor_percentile=10, driver_ttls=None): + def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None): """ Get the valve open times from the raw timeline voltage trace. @@ -299,44 +438,82 @@ def get_valve_open_times(self, display=False, threshold=-2.5, floor_percentile=1 display : bool Plot detected times on the raw voltage trace. threshold : float - The threshold for applying to analogue channels. - floor_percentile : float - 10% removes the percentile value of the analog trace before thresholding. This is to - avoid DC offset drift. + The threshold of voltage change to apply. The default was set by eye; units should be + Volts per sample but doesn't appear to be. driver_ttls : numpy.array An optional array of driver TTLs to use for assigning with the valve times. Returns ------- numpy.array - The detected valve open times. - - TODO extract close times too + The detected valve open intervals. + numpy.array + If driver_ttls is not None, returns an array of open times that occurred directly after + the driver TTLs. """ + WARN_THRESH = 10e-3 # open time threshold below which to log warning tl = self.timeline info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve') values = tl['raw'][:, info['arrayColumn'] - 1] # Timeline indices start from 1 - offset = np.percentile(values, floor_percentile, axis=0) - idx = falls(values - offset, step=threshold) # Voltage falls when valve opens - open_times = tl['timestamps'][idx] + + # The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz + # making simple thresholding an issue. For this reason we convolve the signal with a + # window and detect the peaks and troughs. + if (Fs := tl['meta']['daqSampleRate']) != 2000: # e.g. 2kHz + _logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs) + dt = 1e-3 # change in voltage takes ~1ms when changing valve open state + N = dt / (1 / Fs) # this means voltage change occurs over N samples + vel, _ = velocity_filtered(values, int(Fs / N)) # filtered voltage change over time + ups, _ = find_peaks(vel, height=threshold) # valve closes (-5V -> 0V) + downs, _ = find_peaks(-1 * vel, height=threshold) # valve opens (0V -> -5V) + + # Convert these times into intervals + ixs = np.argsort(np.r_[downs, ups]) # sort indices + times = tl['timestamps'][np.r_[downs, ups]][ixs] # ordered valve event times + polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs] # polarity sorted + missing = np.where(np.diff(polarities) == 0)[0] # if some changes were missed insert NaN + times = np.insert(times, missing + int(polarities[0] == -1), np.nan) + if polarities[-1] == -1: # ensure ends with a valve close + times = np.r_[times, np.nan] + if polarities[0] == 1: # ensure starts with a valve open + # It seems it can start out at -5V (open), then when the reward happens it closes and + # immediately opens. In this case we insert discard the first open time. + times = np.r_[np.nan, times] + intervals = times.reshape(-1, 2) + + # Log warning of improbably short intervals + short = np.sum(np.diff(intervals) < WARN_THRESH) + if short > 0: + _logger.warning('%i valve open intervals shorter than %i ms', short, WARN_THRESH) + # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL if driver_ttls is not None: # Returns an array of open_times indices, one for each driver TTL - ind = attribute_times(open_times, driver_ttls, tol=.1, take='after') - open_times = open_times[ind[ind >= 0]] + ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after') + open_times = intervals[ind[ind >= 0], 0] # TODO Log any > 40ms? Difficult to report missing valve times because of calibration if display: fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True) - ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), 'k-o') + ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-') if driver_ttls is not None: - vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b') - ax1.plot(tl['timestamps'], values - offset, 'k-o') + x = np.empty_like(driver_ttls.flatten()) + x[0::2] = driver_ttls[:, 0] + x[1::2] = driver_ttls[:, 1] + y = np.ones_like(x) + y[1::2] -= 2 + squares(x, y, ax=ax0, yrange=[0, 5]) + # vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b') + ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*') + ax1.plot(tl['timestamps'], values, 'k-o') ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s') - ax1.plot(tl['timestamps'][idx], np.zeros_like(idx), 'r*') - if driver_ttls is not None: - ax1.plot(open_times, np.zeros_like(open_times), 'g*') - return open_times + + ax2 = ax1.twinx() + ax2.set_ylabel('dV', color='grey') + ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey') + ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close') + ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open') + return intervals if driver_ttls is None else (intervals, open_times) def _assign_events_audio(self, audio_times, audio_polarities, display=False): """ @@ -360,7 +537,7 @@ def _assign_events_audio(self, audio_times, audio_polarities, display=False): """ # make sure that there are no 2 consecutive fall or consecutive rise events assert np.all(np.abs(np.diff(audio_polarities)) == 2) - # take only even time differences: ie. from rising to falling fronts + # take only even time differences: i.e. from rising to falling fronts dt = np.diff(audio_times) onsets = audio_polarities[:-1] == 1 diff --git a/ibllib/io/raw_daq_loaders.py b/ibllib/io/raw_daq_loaders.py index add980130..8ac58c3e7 100644 --- a/ibllib/io/raw_daq_loaders.py +++ b/ibllib/io/raw_daq_loaders.py @@ -292,7 +292,7 @@ def extract_sync_timeline(timeline, chmap=None, floor_percentile=10, threshold=N # Bidirectional; extract indices where delta != 0 raw = correct_counter_discontinuities(raw) d = np.diff(raw) - ind, = np.where(d.astype(int)) + ind, = np.where(~np.isclose(d, 0)) sync.polarities = np.concatenate((sync.polarities, np.sign(d[ind]).astype('i1'))) ind += 1 else: From 6b1416cb9a03dc3f2cd7dd064ab6830ae532666d Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 8 Dec 2023 18:17:58 +0200 Subject: [PATCH 53/68] Build trials after syncing Bpod clock --- ibllib/io/extractors/ephys_fpga.py | 301 ++++++++++++++++------------ ibllib/io/extractors/mesoscope.py | 136 +++++++------ ibllib/pipes/ephys_preprocessing.py | 11 +- 3 files changed, 252 insertions(+), 196 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index aa042ce8e..3c805293f 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -36,6 +36,7 @@ import uuid import re import warnings +from functools import partial import matplotlib.pyplot as plt from matplotlib.colors import TABLEAU_COLORS @@ -917,8 +918,9 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', Below are the steps involved: 0. Load sync and bpod trials, if required. 1. Determine protocol period and discard sync events outside the task. - 2. Classify and attribute DAQ TTLs to trial events (see :meth:`FpgaTrials.extract_behaviour_sync`). + 2. Classify multiplexed TTL events based on length (see :meth:`FpgaTrials.build_trials`). 3. Sync the Bpod clock to the DAQ clock using one of the assigned trial events. + 4. Assign classified TTL events to trial events based on order within the trial. 4. Convert Bpod software event times to DAQ clock. 5. Extract the wheel from the DAQ rotary encoder signal, if required. @@ -989,11 +991,8 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', _logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)', *sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60) - # Get the trial events from the DAQ sync TTLs - fpga_trials = self.extract_behaviour_sync(sync, chmap, **kwargs) - - # Sync clocks and build final trials datasets - out = self.build_trials(fpga_trials, sync=sync, chmap=chmap, **kwargs) + # Get the trial events from the DAQ sync TTLs, sync clocks and build final trials datasets + out = self.build_trials(sync=sync, chmap=chmap, **kwargs) # extract the wheel data if any(x.startswith('wheel') for x in self.var_names): @@ -1016,7 +1015,7 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', assert self.var_names == tuple(out.keys()) return out - def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + def build_trials(self, sync, chmap, display=False, **kwargs): """ Extract task related event times from the sync. @@ -1042,10 +1041,6 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. - start_times : numpy.array - An optional array of timestamps to separate trial events by. This is useful if after - syncing the clocks, some trial start TTLs are found to be missed. If None, uses - 'trial_start' Bpod event. display : bool, matplotlib.pyplot.Axes Show the full session sync pulses display. @@ -1068,50 +1063,58 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * 'Expected at least "trial_start", "trial_end", and "valve_open" audio events. ' '`bpod_event_ttls` kwarg may be incorrect.') - # The first trial pulse is longer and often assigned to another event. - # Here we move the earliest non-trial_start event to the trial_start array. - t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start - pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] - if pretrial: - (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event - dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log - _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) - bpod_event_intervals['trial_start'] = np.r_[ - bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start'] - ] - bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] - - t_trial_start = bpod_event_intervals['trial_start'][:, 0] t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T - # Some protocols, e.g. Guido's ephys biased opto task, have no trial end TTL. - # This is not essential as the trial start is used to sync the clocks. - if t_trial_end.size == 0: - _logger.warning('No trial end / ITI in TTLs found') - t_trial_end = np.full_like(t_trial_start, np.nan) - else: - # Drop last trial start if incomplete - t_trial_start = t_trial_start[:len(t_trial_end)] - t_valve_open = bpod_event_intervals['valve_open'][:, 0] - t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] - t_error_tone_in = audio_event_intervals['error_tone'][:, 0] - - start_times = start_times or t_trial_start - - trials = alfio.AlfBunch({ - 'goCue_times': _assign_events_to_trial(start_times, t_ready_tone_in, take='first'), - 'errorCue_times': _assign_events_to_trial(start_times, t_error_tone_in), - 'valveOpen_times': _assign_events_to_trial(start_times, t_valve_open), - 'stimFreeze_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2), - 'stimOn_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first'), - 'stimOff_times': _assign_events_to_trial(start_times, self.frame2ttl['times']), - 'itiIn_times': _assign_events_to_trial(start_times, t_iti_in) + fpga_events = alfio.AlfBunch({ + 'goCue_times': audio_event_intervals['ready_tone'][:, 0], + 'errorCue_times': audio_event_intervals['error_tone'][:, 0], + 'valveOpen_times': bpod_event_intervals['valve_open'][:, 0], + 'valveClose_times': bpod_event_intervals['valve_open'][:, 1], + 'itiIn_times': t_iti_in, + 'intervals_0': bpod_event_intervals['trial_start'][:, 0], + 'intervals_1': t_trial_end }) - # feedback times are valve open on correct trials and error tone in on incorrect trials - trials['feedback_times'] = np.copy(trials['valveOpen_times']) - ind_err = np.isnan(trials['valveOpen_times']) - trials['feedback_times'][ind_err] = trials['errorCue_times'][ind_err] - trials['intervals'] = np.c_[start_times, t_trial_end] + # Sync the Bpod clock to the DAQ. + # NB: The Bpod extractor typically drops the final, incomplete, trial. Hence there is + # usually at least one extra FPGA event. This shouldn't affect the sync. The final trial is + # dropped after assigning the FPGA events, using the `ifpga` index. Doing this after + # assigning the FPGA trial events ensures the last trial has the correct timestamps. + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) + + if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': + # One issue is that sometimes pulses may not have been detected, in this case + # add the events that have not been detected and re-extract the behaviour sync. + # This is only really relevant for the Bpod interval events as the other TTLs are + # from devices where a missing TTL likely means the Bpod event was truly absent. + _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') + bpod_start = self.bpod_trials['intervals'][:, 0] + missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) + t_trial_start = np.sort(np.r_[fpga_events['intervals_0'][:, 0], missing_bpod]) + else: + t_trial_start = fpga_events['intervals_0'] + + # Assign the FPGA events to individual trials + fpga_trials = { + 'goCue_times': _assign_events_to_trial(t_trial_start, fpga_events['goCue_times'], take='first'), + 'errorCue_times': _assign_events_to_trial(t_trial_start, fpga_events['errorCue_times']), + 'valveOpen_times': _assign_events_to_trial(t_trial_start, fpga_events['valveOpen_times']), + 'itiIn_times': _assign_events_to_trial(t_trial_start, fpga_events['itiIn_times']), + 'stimFreeze_times': _assign_events_to_trial(t_trial_start, self.frame2ttl['times'], take=-2), + 'stimOn_times': _assign_events_to_trial(t_trial_start, self.frame2ttl['times'], take='first'), + 'stimOff_times': _assign_events_to_trial(t_trial_start, self.frame2ttl['times']) + } + + # Feedback times are valve open on correct trials and error tone in on incorrect trials + fpga_trials['feedback_times'] = np.copy(fpga_trials['valveOpen_times']) + ind_err = np.isnan(fpga_trials['valveOpen_times']) + fpga_trials['feedback_times'][ind_err] = fpga_trials['errorCue_times'][ind_err] + + out = alfio.AlfBunch() + # Add the Bpod trial events, converting the timestamp fields to FPGA time. + # NB: The trial intervals are by default a Bpod rsync field. + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + out.update({k: fpga_trials[k][ifpga] for k in fpga_trials.keys()}) if display: # pragma: no cover width = 0.5 @@ -1125,34 +1128,13 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') color_map = TABLEAU_COLORS.keys() - for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): + for (event_name, event_times), c in zip(fpga_events.items(), cycle(color_map)): plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) ax.legend() ax.set_yticks([0, 1, 2, 3]) ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) ax.set_ylim([0, 5]) - return trials - - def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): - # Sync the Bpod clock to the DAQ - self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) - - if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': - # One issue is that sometimes pulses may not have been detected, in this case - # add the events that have not been detected and re-extract the behaviour sync. - # This is only really relevant for the Bpod interval events as the other TTLs are - # from devices where a missing TTL likely means the Bpod event was truly absent. - _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') - bpod_start = self.bpod_trials['intervals'][:, 0] - missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) - t_trial_start = np.sort(np.r_[fpga_trials['intervals'][:, 0], missing_bpod]) - fpga_trials = self.extract_behaviour_sync(sync, chmap, start_times=t_trial_start, **kwargs) - - out = dict() - out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) - out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) - out.update({k: fpga_trials[k][ifpga] for k in sorted(fpga_trials.keys())}) return out def get_wheel_positions(self, *args, **kwargs): @@ -1233,7 +1215,7 @@ def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, Gets the Bpod TTL times from the sync 'bpod' channel and classifies each TTL event by length. NB: The first trial has an abnormal trial_start TTL that is usually mis-assigned. - This is handled in the :meth:`FpgaTrials.extract_behaviour_sync` method. + This method accounts for this. Parameters ---------- @@ -1268,6 +1250,22 @@ def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, bpod_event_intervals = self._assign_events( bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) + if 'trial_start' not in bpod_event_intervals or bpod_event_intervals['trial_start'].size == 0: + return bpod, bpod_event_intervals + + # The first trial pulse is longer and often assigned to another event. + # Here we move the earliest non-trial_start event to the trial_start array. + t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start + pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] + if pretrial: + (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event + dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log + _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) + bpod_event_intervals['trial_start'] = np.r_[ + bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start'] + ] + bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] + return bpod, bpod_event_intervals @staticmethod @@ -1364,8 +1362,8 @@ def sync_bpod_clock(bpod_trials, fpga_trials, sync_field): bpod_trials : dict A dictionary of extracted Bpod trial events. fpga_trials : dict - A dictionary of trial events extracted from FPGA sync events (see - `extract_behaviour_sync` method). + A dictionary of TTL events extracted from FPGA sync (see `extract_behaviour_sync` + method). sync_field : str The trials key to use for syncing clocks. For intervals (i.e. Nx2 arrays) append the column index, e.g. 'intervals_0'. @@ -1387,27 +1385,28 @@ def sync_bpod_clock(bpod_trials, fpga_trials, sync_field): The key `sync_field` was not found in either the `bpod_trials` or `fpga_trials` dicts. """ _logger.info(f'Attempting to align Bpod clock to DAQ using trial event "{sync_field}"') - if sync_field not in bpod_trials: - # handle syncing on intervals - if not (m := re.match(r'(.*)_(\d)', sync_field)): - raise ValueError(f'Sync field "{sync_field}" not in extracted bpod trials') - sync_field, i = m.groups() - timestamps_bpod = bpod_trials[sync_field][:, int(i)] - timestamps_fpga = fpga_trials[sync_field][:, int(i)] - elif sync_field not in fpga_trials: - raise ValueError(f'Sync field "{sync_field}" not in extracted fpga trials') - else: - timestamps_bpod = bpod_trials[sync_field] - timestamps_fpga = fpga_trials[sync_field] + bpod_fpga_timestamps = [None, None] + for i, trials in enumerate((bpod_trials, fpga_trials)): + if sync_field not in trials: + # handle syncing on intervals + if not (m := re.match(r'(.*)_(\d)', sync_field)): + # If missing from bpod trials, either the sync field is incorrect, + # or the Bpod extractor is incorrect. If missing from the fpga events, check + # the sync field and the `extract_behaviour_sync` method. + raise ValueError( + f'Sync field "{sync_field}" not in extracted {"fpga" if i else "bpod"} events') + _sync_field, n = m.groups() + bpod_fpga_timestamps[i] = trials[_sync_field][:, int(n)] + else: + bpod_fpga_timestamps[i] = trials[sync_field] # Sync the two timestamps - fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps( - timestamps_bpod, timestamps_fpga, return_indices=True) + fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps(*bpod_fpga_timestamps, return_indices=True) # If it's drifting too much throw warning or error _logger.info('N trials: %i bpod, %i FPGA, %i merged, sync %.5f ppm', - len(timestamps_bpod), len(timestamps_fpga), len(ibpod), drift) - if drift > 200 and timestamps_bpod.size != timestamps_fpga.size: + *map(len, bpod_fpga_timestamps), len(ibpod), drift) + if drift > 200 and bpod_fpga_timestamps[0].size != bpod_fpga_timestamps[1].size: raise err.SyncBpodFpgaException('sync cluster f*ck') elif drift > BPOD_FPGA_DRIFT_THRESHOLD_PPM: _logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm', @@ -1481,24 +1480,65 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', A dictionary of numpy arrays with `FpgaTrialsHabituation.var_names` as keys. """ # Version check: the ITI in TTL was added in a later version + if not self.settings: + self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) iblrig_version = version.parse(self.settings.get('IBL_VERSION', '0.0.0')) if version.parse('8.9.3') <= iblrig_version < version.parse('8.12.6'): """A second 1s TTL was added in this version during the 'iti' state, however this is unrelated to the trial ITI and is unfortunately the same length as the trial start TTL.""" raise NotImplementedError('Ambiguous TTLs in 8.9.3 >= version < 8.12.6') - # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse - if 'bpod_event_ttls' not in kwargs: - kwargs['bpod_event_ttls'] = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)} trials = super()._extract(sync=sync, chmap=chmap, sync_collection=sync_collection, task_collection=task_collection, **kwargs) - n = trials['intervals'].shape[0] # number of trials - trials['intervals'][:, 1] = self.bpod2fpga(self.bpod_trials['intervals'][:n, 1]) - return trials - def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): + """ + Extract Bpod times from sync. + + Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse. + Also the first trial pulse is incorrectly assigned due to its abnormal length. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. Must contain a 'bpod' key. + chmap : dict + A map of channel names and their corresponding indices. + bpod_event_ttls : dict of tuple + A map of event names to (min, max) TTL length. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. + dict + A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. + """ + bpod = get_sync_fronts(sync, chmap['bpod']) + if bpod.times.size == 0: + raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. ' + 'Check channel maps.') + # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these + # lengths are defined by the state machine of the task protocol and therefore vary. + if bpod_event_ttls is None: + # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse + bpod_event_ttls = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)} + bpod_event_intervals = self._assign_events( + bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) + + # The first trial pulse is shorter and assigned to valve_open. Here we remove the first + # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was + # incomplete in Bpod. + bpod_event_intervals['trial_iti'] = np.r_[bpod_event_intervals['valve_open'][0:1, :], + bpod_event_intervals['trial_iti']] + bpod_event_intervals['valve_open'] = bpod_event_intervals['valve_open'][1:, :] + + return bpod, bpod_event_intervals + + def build_trials(self, sync, chmap, display=False, **kwargs): """ Extract task related event times from the sync. @@ -1511,10 +1551,6 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. - start_times : numpy.array - An optional array of timestamps to separate trial events by. This is useful if after - syncing the clocks, some trial start TTLs are found to be missed. If None, uses - 'trial_start' Bpod event. display : bool, matplotlib.pyplot.Axes Show the full session sync pulses display. @@ -1532,40 +1568,47 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * raise ValueError( 'Expected at least "trial_iti" and "valve_open" Bpod events. `bpod_event_ttls` kwarg may be incorrect.') - # The first trial pulse is shorter and assigned to valve_open. Here we remove the first - # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was - # incomplete in Bpod. + fpga_events = alfio.AlfBunch({ + 'feedback_times': bpod_event_intervals['valve_open'][:, 0], + 'valveClose_times': bpod_event_intervals['valve_open'][:, 1], + 'intervals_0': bpod_event_intervals['trial_iti'][:, 1], + 'intervals_1': bpod_event_intervals['trial_iti'][:, 0], + 'goCue_times': audio_event_intervals['ready_tone'][:, 0] + }) n_trials = self.bpod_trials['intervals'].shape[0] - t_valve_open = bpod_event_intervals['valve_open'][1:, 0] # drop first spurious valve event - t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] - t_trial_start = np.r_[0, bpod_event_intervals['trial_iti'][:, 1]] - t_trial_end = bpod_event_intervals['trial_iti'][:, 0] - start_times = start_times or t_trial_start + # Sync the Bpod clock to the DAQ. + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) + out = alfio.AlfBunch() + # Add the Bpod trial events, converting the timestamp fields to FPGA time. + # NB: The trial intervals are by default a Bpod rsync field. + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + + # Assigning each event to a trial ensures exactly one event per trial (missing events are NaN) + assign_to_trial = partial(_assign_events_to_trial, fpga_events['intervals_0']) trials = alfio.AlfBunch({ - 'goCue_times': _assign_events_to_trial(start_times, t_ready_tone_in, take='first')[:n_trials], - 'feedback_times': _assign_events_to_trial(start_times, t_valve_open)[:n_trials], - 'stimFreeze_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2)[:n_trials], - 'stimOn_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first')[:n_trials], - 'stimOff_times': _assign_events_to_trial(start_times, self.frame2ttl['times'])[:n_trials], - # These 'raw' intervals will be used in the sync - 'intervals_1': _assign_events_to_trial(start_times, t_trial_end), - 'intervals_0': start_times + 'goCue_times': assign_to_trial(fpga_events['goCue_times'], take='first')[:n_trials], + 'feedback_times': assign_to_trial(fpga_events['feedback_times'])[:n_trials], + 'stimCenter_times': assign_to_trial(self.frame2ttl['times'], take=-2)[:n_trials], + 'stimOn_times': assign_to_trial(self.frame2ttl['times'], take='first')[:n_trials], + 'stimOff_times': assign_to_trial(self.frame2ttl['times'])[:n_trials], }) # If stim on occurs before trial end, use stim on time. Likewise for trial end and stim off - trials['intervals'] = np.c_[trials['intervals_0'], trials['intervals_1']][:n_trials, :] - to_correct = ~np.isnan(trials['stimOn_times']) & (trials['stimOn_times'] < trials['intervals'][:, 0]) + to_correct = ~np.isnan(trials['stimOn_times']) & (trials['stimOn_times'] < out['intervals'][:, 0]) if np.any(to_correct): _logger.warning('%i/%i stim on events occurring outside trial intervals', sum(to_correct), len(to_correct)) - trials['intervals'][to_correct, 0] = trials['stimOn_times'][to_correct] - to_correct = ~np.isnan(trials['stimOff_times']) & (trials['stimOff_times'] > trials['intervals'][:, 1]) + out['intervals'][to_correct, 0] = trials['stimOn_times'][to_correct] + to_correct = ~np.isnan(trials['stimOff_times']) & (trials['stimOff_times'] > out['intervals'][:, 1]) if np.any(to_correct): _logger.debug( '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end', sum(to_correct), len(to_correct)) - trials['intervals'][to_correct, 1] = trials['stimOff_times'][to_correct] + out['intervals'][to_correct, 1] = trials['stimOff_times'][to_correct] + + out.update({k: trials[k][ifpga] for k in trials.keys()}) if display: # pragma: no cover width = 0.5 @@ -1586,7 +1629,7 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) ax.set_ylim([0, 4]) - return trials + return out def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_path=None, @@ -1630,6 +1673,7 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ 'ibllib.io.extractors.ephys_fpga.extract_all will be removed in future versions; ' 'use FpgaTrials instead. For reliable extraction, use the dynamic pipeline behaviour tasks.', FutureWarning) + return_extractor = kwargs.pop('return_extractor', False) # Extract Bpod trials bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' @@ -1646,7 +1690,10 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ task_collection=task_collection, protocol_number=protocol_number, **kwargs) if not isinstance(outputs, dict): outputs = {k: v for k, v in zip(trials.var_names, outputs)} - return outputs, files + if return_extractor: + return outputs, files, trials + else: + return outputs, files def get_sync_and_chn_map(session_path, sync_collection): diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 84a7622e7..e4ca6766b 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -1,6 +1,5 @@ """Mesoscope (timeline) data extraction.""" import logging -from itertools import cycle import numpy as np from scipy.signal import find_peaks @@ -8,7 +7,6 @@ from one.util import ensure_list from one.alf.files import session_path_parts import matplotlib.pyplot as plt -from matplotlib.colors import TABLEAU_COLORS from pkg_resources import parse_version from ibllib.plots.misc import squares, vertical_lines @@ -146,26 +144,51 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa plot_timeline(self.timeline, channels=chmap.keys(), raw=True) return trials - def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): """ - Extract task related event times from the sync. + Extract Bpod times from sync. + + Unlike the superclass method. This one doesn't reassign the first trial pulse. - TODO Change docstring - The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The - first trial start TTL of the session is longer and must be handled differently. The trial - start TTL is used to assign the other trial events to each trial. + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. Must contain a 'bpod' key. + chmap : dict + A map of channel names and their corresponding indices. + bpod_event_ttls : dict of tuple + A map of event names to (min, max) TTL length. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. + dict + A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. + """ + # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these + # lengths are defined by the state machine of the task protocol and therefore vary. + if bpod_event_ttls is None: + # The trial start TTLs are often too short for the low sampling rate of the DAQ and are + # therefore not used in extraction + bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)} + bpod, bpod_event_intervals = super().get_bpod_event_times( + sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs) + + # TODO Here we can make use of the 'bpod_rising_edge' channel, if available + return bpod, bpod_event_intervals + + def build_trials(self, sync=None, chmap=None, **kwargs): + """ + Extract task related event times from the sync. - The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest - of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio - tones. The first of these after each trial start is taken to be the go cue time. Error - tones are longer audio TTLs and assigned as the last of such occurrence after each trial - start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial. - The feedback times are times of either valve open or error tone as there should be only one - such event per trial. + The two major differences are that the sampling rate is lower for imaging so the short Bpod + trial start TTLs are often absent. For this reason, the sync happens using the ITI_in TTL. - The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs - removed): the first TTL after each trial start is assumed to be the stim onset time; the - second to last and last are taken as the stimulus freeze and offset times, respectively. + Second, the valve used at the mesoscope has a way to record the raw voltage across the + solenoid, giving a more accurate readout of the valve's activity. If the reward_valve + channel is present on the DAQ, this is used to extract the valve open times. Parameters ---------- @@ -173,12 +196,6 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. - start_times : numpy.array - An optional array of timestamps to separate trial events by. This is useful if after - syncing the clocks, some trial start TTLs are found to be missed. If None, uses - 'trial_start' Bpod event. - display : bool, matplotlib.pyplot.Axes - Show the full session sync pulses display. Returns ------- @@ -201,46 +218,36 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * '`bpod_event_ttls` kwarg may be incorrect.') t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T - trials = alfio.AlfBunch({ + fpga_events = alfio.AlfBunch({ 'itiIn_times': t_iti_in, 'intervals_1': t_trial_end, - 'valveOpen_intervals': bpod_event_intervals['valve_open'], 'goCue_times': audio_event_intervals['ready_tone'][:, 0], 'errorTone_times': audio_event_intervals['error_tone'][:, 0] }) - if display: # pragma: no cover - width = 0.5 - ymax = 5 - if isinstance(display, bool): - plt.figure('Bpod FPGA Sync') - ax = plt.gca() - else: - ax = display - squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') - squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') - squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') - color_map = TABLEAU_COLORS.keys() - for (event_name, event_times), c in zip(trials.items(), cycle(color_map)): - vertical_lines(event_times.flat, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) - ax.legend() - ax.set_yticks([0, 1, 2, 3]) - ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) - ax.set_ylim([0, 4]) - - return trials - - def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): # Sync the Bpod clock to the DAQ - self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) out = dict() - out['intervals'] = self.bpod2fpga(self.bpod_trials['intervals']) - out['itiIn_times'] = fpga_trials['itiIn_times'][ifpga] + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + start_times = out['intervals'][:, 0] + last_trial_end = out['intervals'][-1, 1] + + def assign_to_trial(events, take='last'): + """Assign DAQ events to trials. + + Because we may not have trial start TTLs on the DAQ (because of the low sampling rate), + there may be an extra last trial that's not in the Bpod intervals as the extractor + ignores the last trial. This function trims the input array before assigning so that + the last trial's events are correctly assigned. + """ + return _assign_events_to_trial(start_times, events[events <= last_trial_end], take) + out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga]) # Extract valve open times from the DAQ - valve_driver_ttls = fpga_trials.pop('valveOpen_intervals') + valve_driver_ttls = bpod_event_intervals['valve_open'] correct = self.bpod_trials['feedbackType'] == 1 # If there is a reward_valve channel, the valve has if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']): @@ -260,25 +267,25 @@ def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): 'Number of valve open times does not equal number of correct trials (%i != %i)', valve_open_times.size, np.sum(correct)) - out['valveOpen_times'] = _assign_events_to_trial(start_times, valve_open_times) + out['valveOpen_times'] = assign_to_trial(valve_open_times) else: # Use the valve controller TTLs recorded on the Bpod channel as the reward time - out['valveOpen_times'] = _assign_events_to_trial(start_times, valve_driver_ttls[:, 0]) + out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0]) # Stimulus times extracted the same as usual - out['stimFreeze_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2) - out['stimOn_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first') - out['stimOff_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times']) + out['stimFreeze_times'] = assign_to_trial(self.frame2ttl['times'], take=-2) + out['stimOn_times'] = assign_to_trial(self.frame2ttl['times'], take='first') + out['stimOff_times'] = assign_to_trial(self.frame2ttl['times']) # Audio times - error_cue = fpga_trials['errorTone_times'] + error_cue = fpga_events['errorTone_times'] if error_cue.size != np.sum(~correct): _logger.warning( 'N detected error tones does not match number of incorrect trials (%i != %i)', error_cue.size, np.sum(~correct)) - go_cue = fpga_trials['goCue_times'] - out['goCue_times'] = _assign_events_to_trial(start_times, go_cue, take='first') - out['errorCue_times'] = _assign_events_to_trial(start_times, error_cue) + go_cue = fpga_events['goCue_times'] + out['goCue_times'] = assign_to_trial(go_cue, take='first') + out['errorCue_times'] = assign_to_trial(error_cue) if go_cue.size > start_times.size: _logger.warning( @@ -296,8 +303,8 @@ def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): assert err_trig.size == go_trig.size # should be length of n trials with NaNs # Find which trials are missing a go cue - _go_cue = _assign_events_to_trial(start_times, go_cue, take='first') - error_cue = _assign_events_to_trial(start_times, error_cue) + _go_cue = assign_to_trial(go_cue, take='first') + error_cue = assign_to_trial(error_cue) missing = np.isnan(_go_cue) # Get all the DAQ timestamps where audio channel was HIGH @@ -334,9 +341,6 @@ def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): ind_err = np.isnan(out['valveOpen_times']) out['feedback_times'][ind_err] = out['errorCue_times'][ind_err] - out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) - out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) - return out def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None): diff --git a/ibllib/pipes/ephys_preprocessing.py b/ibllib/pipes/ephys_preprocessing.py index 26cef7050..09591ce2d 100644 --- a/ibllib/pipes/ephys_preprocessing.py +++ b/ibllib/pipes/ephys_preprocessing.py @@ -694,7 +694,8 @@ def _behaviour_criterion(self): ) def _extract_behaviour(self): - dsets, out_files = ephys_fpga.extract_all(self.session_path, save=True) + dsets, out_files, self.extractor = ephys_fpga.extract_all( + self.session_path, save=True, return_extractor=True) return dsets, out_files @@ -709,8 +710,12 @@ def _run(self, plot_qc=True): qc = TaskQC(self.session_path, one=self.one, log=_logger) qc.extractor = TaskQCExtractor(self.session_path, lazy=True, one=qc.one) # Extract extra datasets required for QC - qc.extractor.data = dsets - qc.extractor.extract_data() + qc.extractor.data = qc.extractor.rename_data(dsets) + wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps']) + qc.extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod + qc.extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] + qc.extractor.wheel_encoding = 'X4' + # Aggregate and update Alyx QC fields qc.run(update=True) From bac237450b6b1011c336af7d564f73311005e45d Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 8 Dec 2023 18:48:12 +0200 Subject: [PATCH 54/68] Include wheel in Bpod trials dict passed to FpgaTrials --- ibllib/io/extractors/ephys_fpga.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 3c805293f..187d216f6 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -1677,14 +1677,14 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ # Extract Bpod trials bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' - bpod_trials, *_ = bpod_extract_all( + bpod_trials, bpod_wheel, *_ = bpod_extract_all( session_path=session_path, bpod_trials=bpod_raw, task_collection=task_collection, save=False, extractor_type=kwargs.get('extractor_type')) # Sync Bpod trials to FPGA sync, chmap = get_sync_and_chn_map(session_path, sync_collection) # sync, chmap = get_main_probe_sync(session_path, bin_exists=bin_exists) - trials = FpgaTrials(session_path, bpod_trials=bpod_trials) + trials = FpgaTrials(session_path, bpod_trials=bpod_trials | bpod_wheel) outputs, files = trials.extract( save=save, sync=sync, chmap=chmap, path_out=save_path, task_collection=task_collection, protocol_number=protocol_number, **kwargs) From ba2553dbb757d5049ec39da5f9f6b3d88baa6ab3 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 8 Dec 2023 19:18:05 +0200 Subject: [PATCH 55/68] Add more fields to qc extractor --- ibllib/pipes/ephys_preprocessing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ibllib/pipes/ephys_preprocessing.py b/ibllib/pipes/ephys_preprocessing.py index 09591ce2d..7ea845d18 100644 --- a/ibllib/pipes/ephys_preprocessing.py +++ b/ibllib/pipes/ephys_preprocessing.py @@ -715,6 +715,10 @@ def _run(self, plot_qc=True): qc.extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod qc.extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] qc.extractor.wheel_encoding = 'X4' + qc.extractor.settings = self.extractor.settings + qc.extractor.frame_ttls = self.extractor.frame2ttl + qc.extractor.audio_ttls = self.extractor.audio + qc.extractor.bpod_ttls = self.extractor.bpod # Aggregate and update Alyx QC fields qc.run(update=True) From f3f057c13573e0e1d56f93590c396a83faa25fd2 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 11 Dec 2023 12:36:14 +0200 Subject: [PATCH 56/68] flake --- ibllib/io/extractors/mesoscope.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 299c6fca9..20b349eb0 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -7,7 +7,6 @@ from one.util import ensure_list from one.alf.files import session_path_parts import matplotlib.pyplot as plt -from neurodsp.utils import falls from packaging import version from ibllib.plots.misc import squares, vertical_lines From f5290b05d613ff9eb71fdcea63f01a9b31d757d7 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 11 Dec 2023 14:29:37 +0200 Subject: [PATCH 57/68] Skip phase distribution check; handle NaNs in stim move before go cue --- ibllib/qc/task_metrics.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 99c229f15..efe30f73c 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -275,10 +275,6 @@ def compute_session_status(self): class HabituationQC(TaskQC): - criteria = dict() - criteria['default'] = {'PASS': 0.99, 'WARNING': 0.90, 'FAIL': 0} # Note: WARNING was 0.95 prior to Aug 2022 - criteria['_task_phase_distribution'] = {'PASS': 0.99, 'NOT_SET': 0} # This rarely passes due to low trial num - def compute(self, download_data=None, **kwargs): """Compute and store the QC metrics. @@ -368,7 +364,7 @@ def compute(self, download_data=None, **kwargs): check = prefix + 'phase_distribution' metric, _ = np.histogram(data['phase']) _, p = chisquare(metric) - passed[check] = p < 0.05 + passed[check] = p < 0.05 if len(data['phase']) >= 400 else None # skip if too few trials metrics[check] = metric # Checks common to training QC @@ -1075,7 +1071,7 @@ def check_wheel_integrity(data, re_encoding='X1', enc_res=None, **_): # === Pre-stimulus checks === def check_stimulus_move_before_goCue(data, photodiode=None, **_): """ Check that there are no visual stimulus change(s) between the start of the trial and the - go cue sound onset, expect for stim on. + go cue sound onset, except for stim on. Metric: M = number of visual stimulus change events between trial start and goCue_times Criterion: M == 1 @@ -1088,6 +1084,7 @@ def check_stimulus_move_before_goCue(data, photodiode=None, **_): ----- - There should be exactly 1 stimulus change before goCue; stimulus onset. Even if the stimulus contrast is 0, the sync square will still flip at stimulus onset, etc. + - If there are no goCue times (all are NaN), the status should be NOT_SET. """ if photodiode is None: _log.warning('No photodiode TTL input in function call, returning None') @@ -1100,6 +1097,7 @@ def check_stimulus_move_before_goCue(data, photodiode=None, **_): metric = np.append(metric, np.count_nonzero(s[s > i] < c)) passed = (metric == 1).astype(float) + passed[np.isnan(data['goCue_times'])] = np.nan assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed From 0ca5082fbd62191d6fbe7970af353bd1264241a9 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 11 Dec 2023 15:11:40 +0200 Subject: [PATCH 58/68] use get_bpod_event_times --- ibllib/io/extractors/ephys_fpga.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index c4516c7c0..66dcda9ab 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -982,12 +982,14 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', # Older sessions don't have protocol spacers so we sync the Bpod intervals here to # find the approximate end time of the protocol (this will exclude the passive signals # in ephysChoiceWorld that tend to ruin the final trial extraction). - t_trial_start, *_ = _assign_events_bpod(bpod['times'], bpod['polarities']) - bpod_start = self.bpod_trials['intervals_bpod'][:, 0] - if len(t_trial_start) > len(bpod_start) / 2: + _, trial_ints = self.get_bpod_event_times(sync, chmap, **kwargs) + t_trial_start = trial_ints.get('trial_start', np.array([[np.nan, np.nan]]))[:, 0] + bpod_start = self.bpod_trials['intervals'][:, 0] + if len(t_trial_start) > len(bpod_start) / 2: # if least half the trial start TTLs detected + _logger.warning('Attempting to get protocol period from aligning trial start TTLs') fcn, *_ = neurodsp.utils.sync_timestamps(bpod_start, t_trial_start) - tmin = fcn(trials_table['intervals'][0, 0]) - 1 - tmax = fcn(trials_table['intervals'][-1, 1]) + 1 + tmin = fcn(self.bpod_trials['intervals'][0, 0]) - 1 + tmax = fcn(self.bpod_trials['intervals'][-1, 1]) + 1 else: # This type of alignment fails for some sessions, e.g. mesoscope tmin = tmax = None From b31d14e5113180b50621c985b2f230ba84da1dd3 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 12 Dec 2023 13:36:53 +0200 Subject: [PATCH 59/68] Increase default protocol period --- ibllib/io/extractors/ephys_fpga.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 66dcda9ab..ed90efb61 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -988,8 +988,10 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', if len(t_trial_start) > len(bpod_start) / 2: # if least half the trial start TTLs detected _logger.warning('Attempting to get protocol period from aligning trial start TTLs') fcn, *_ = neurodsp.utils.sync_timestamps(bpod_start, t_trial_start) - tmin = fcn(self.bpod_trials['intervals'][0, 0]) - 1 - tmax = fcn(self.bpod_trials['intervals'][-1, 1]) + 1 + buffer = 2.5 # the number of seconds to include before/after task + start, end = fcn(self.bpod_trials['intervals'].flat[[0, -1]]) + tmin = min(sync['times'][0], start - buffer) + tmax = max(sync['times'][-1], end + buffer) else: # This type of alignment fails for some sessions, e.g. mesoscope tmin = tmax = None From 49ad5a3326d8353687d45c2a83642300c73473d2 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 12 Dec 2023 15:25:28 +0200 Subject: [PATCH 60/68] Test new extractor methods --- ibllib/tests/extractors/test_ephys_fpga.py | 43 ++++++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/ibllib/tests/extractors/test_ephys_fpga.py b/ibllib/tests/extractors/test_ephys_fpga.py index fdfe27218..f5bf85491 100644 --- a/ibllib/tests/extractors/test_ephys_fpga.py +++ b/ibllib/tests/extractors/test_ephys_fpga.py @@ -4,6 +4,7 @@ from pathlib import Path import numpy as np +import numpy.testing from ibllib.io.extractors import ephys_fpga import spikeglx @@ -164,7 +165,7 @@ def test_audio_ttl_wiring_camera(self): audio_ = ephys_fpga._clean_audio(audio) expected = {'times': np.array([3399.4090251, 3399.50411559, 3400.306602, 3400.80061926]), 'polarities': np.array([1., -1., 1., -1.])} - assert all([np.all(audio_[k] == expected[k]) for k in audio_]) + assert all(np.all(audio_[k] == expected[k]) for k in audio_) def test_audio_ttl_start_up_down(self): """ @@ -173,18 +174,28 @@ def test_audio_ttl_start_up_down(self): The extraction should handle both cases seamlessly: cf eid d839491f-55d8-4cbe-a298-7839208ba12b """ - def _test_audio(audio): - ready, error = ephys_fpga._assign_events_audio(audio['times'], audio['polarities']) - assert np.all(ready == audio['times'][audio['ready_tone']]) - assert np.all(error == audio['times'][audio['error_tone']]) + def _test_audio(audio, audio_intervals): + for tone, intervals in audio_intervals.items(): + if tone == 'unassigned': + self.assertFalse(len(intervals)) + else: + np.testing.assert_array_almost_equal( + intervals[:, 0], audio['times'][audio[tone]], err_msg=tone) audio = { 'times': np.array([1740.1032, 1740.20176667, 1741.0786, 1741.57713333, 1744.78716667, 1744.88573333]), 'polarities': np.array([1., -1., 1., -1., 1., -1.]), 'error_tone': np.array([False, False, True, False, False, False]), - 'ready_tone': np.array([True, False, False, False, True, False]) + 'ready_tone': np.array([True, False, False, False, True, False]), + 'channels': np.zeros(6, dtype=int) } - _test_audio(audio) # this tests the usual pulses - _test_audio({k: audio[k][1:] for k in audio}) # this tests when it starts in upstate + extractor = ephys_fpga.FpgaTrials('subject/2023-01-01/000') # placeholder session path unused + _audio, audio_intervals = extractor.get_audio_event_times(audio, {'audio': 0}) + _test_audio(audio, audio_intervals) # this tests the usual pulses + audio['channels'][0] = 1 # this tests when it starts in upstate + audio['ready_tone'][0] = False # the first ready tone should now be skipped + _audio, audio_intervals = extractor.get_audio_event_times(audio, {'audio': 0}) + _test_audio(audio, audio_intervals) + np.testing.assert_array_equal(_audio['times'], audio['times'][1:]) def test_ttl_bpod_gaelle_writes_protocols_but_guido_doesnt_read_them(self): bpod_t = np.array([5.423290950005423, 6.397993470006398, 6.468919710006469, @@ -194,13 +205,21 @@ def test_ttl_bpod_gaelle_writes_protocols_but_guido_doesnt_read_them(self): 17.175015660017174, 18.204012750018205, 18.704029410018705, 19.286337840019286, 19.28643783001929, 21.76005711002176, 21.83095002002183, 22.85998044002286]) + pol = (np.mod(np.arange(bpod_t.size), 2) - 0.5) * 2 + sync = {'times': bpod_t, 'polarities': pol, 'channels': np.zeros_like(bpod_t, dtype=int)} + extractor = ephys_fpga.FpgaTrials('subject/2023-01-01/000') # placeholder session path unused + _, bpod_intervals = extractor.get_bpod_event_times(sync, {'bpod': 0}) + + expected = bpod_t[[1, 5, 9, 15]] + np.testing.assert_array_equal(bpod_intervals['trial_start'][:, 0], expected) + # when the bpod has started before the ephys, the first pulses may have been missed # and the first TTL may be negative/. This needs to yield the same result as if the # bpod was started properly - pol = (np.mod(np.arange(bpod_t.size), 2) - 0.5) * 2 - st, op, iti = ephys_fpga._assign_events_bpod(bpod_t=bpod_t, bpod_polarities=pol) - st_, op_, iti_ = ephys_fpga._assign_events_bpod(bpod_t=bpod_t[1:], bpod_polarities=pol[1:]) - assert np.all(st == st_) and np.all(op == op_) and np.all(iti_ == iti) + sync['channels'][0] = 1 + _, bpod_intervals_ = extractor.get_bpod_event_times(sync, {'bpod': 0}) + for event, intervals in bpod_intervals.items(): + np.testing.assert_array_equal(intervals, bpod_intervals_[event], err_msg=event) def test_frame2ttl_flickers(self): """ From 48cfcd04dfa0d0cfdff478a34befba26c2c90336 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 12 Dec 2023 16:22:38 +0200 Subject: [PATCH 61/68] Move test_ephys_trials tests to test_ephys_fpga; remove old functions --- ibllib/ephys/ephysqc.py | 47 +-- ibllib/io/extractors/ephys_fpga.py | 212 ------------ ibllib/tests/extractors/test_ephys_fpga.py | 326 ++++++++++++++++-- ibllib/tests/extractors/test_ephys_trials.py | 330 ------------------- ibllib/tests/extractors/test_extractors.py | 21 +- 5 files changed, 326 insertions(+), 610 deletions(-) delete mode 100644 ibllib/tests/extractors/test_ephys_trials.py diff --git a/ibllib/ephys/ephysqc.py b/ibllib/ephys/ephysqc.py index 16ab9f870..483fe51f9 100644 --- a/ibllib/ephys/ephysqc.py +++ b/ibllib/ephys/ephysqc.py @@ -3,7 +3,6 @@ """ from pathlib import Path import logging -import shutil import numpy as np import pandas as pd @@ -19,7 +18,7 @@ from brainbox.metrics.single_units import spike_sorting_metrics from ibllib.ephys import sync_probes, spikes from ibllib.qc import base -from ibllib.io.extractors import ephys_fpga, training_wheel +from ibllib.io.extractors import ephys_fpga from phylib.io import model @@ -370,8 +369,9 @@ def _single_test(assertion, str_ok, str_ko): str_ok="PASS: Bpod", str_ko="FAILED: Bpod") try: # note: tried to depend as little as possible on the extraction code but for the valve... - bpod = ephys_fpga.get_sync_fronts(rawsync, sync_map['bpod']) - _, t_valve_open, _ = ephys_fpga._assign_events_bpod(bpod['times'], bpod['polarities']) + extractor = ephys_fpga.FpgaTrials(ses_path) + bpod_intervals = extractor.get_bpod_event_times(sync, sync_map) + t_valve_open = bpod_intervals['valve_open'][:, 0] res = t_valve_open.size > 1 except AssertionError: res = False @@ -569,42 +569,3 @@ def strictly_after(t0, t1, threshold): qc_session = {k: np.all(qc_trials[k]) for k in qc_trials} return qc_session, qc_trials - - -def _qc_from_path(sess_path, display=True): - WHEEL = False - sess_path = Path(sess_path) - temp_alf_folder = sess_path.joinpath('fpga_test', 'alf') - temp_alf_folder.mkdir(parents=True, exist_ok=True) - - sync, chmap = ephys_fpga.get_main_probe_sync(sess_path, bin_exists=False) - _ = ephys_fpga.extract_all(sess_path, output_path=temp_alf_folder, save=True) - # check that the output is complete - fpga_trials, *_ = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display) - # align with the bpod - bpod2fpga = ephys_fpga.align_with_bpod(temp_alf_folder.parent) - alf_trials = alfio.load_object(temp_alf_folder, 'trials') - shutil.rmtree(temp_alf_folder) - # do the QC - qcs, qct = qc_fpga_task(fpga_trials, alf_trials) - - # do the wheel part - if WHEEL: - bpod_wheel = training_wheel.get_wheel_data(sess_path, save=False) - fpga_wheel = ephys_fpga.extract_wheel_sync(sync, chmap=chmap, save=False) - - if display: - import matplotlib.pyplot as plt - t0 = max(np.min(bpod2fpga(bpod_wheel['re_ts'])), np.min(fpga_wheel['re_ts'])) - dy = np.interp(t0, fpga_wheel['re_ts'], fpga_wheel['re_pos']) - np.interp( - t0, bpod2fpga(bpod_wheel['re_ts']), bpod_wheel['re_pos']) - - fix, axes = plt.subplots(nrows=2, sharex='all', sharey='all') - # axes[0].plot(t, pos), axes[0].title.set_text('Extracted') - axes[0].plot(bpod2fpga(bpod_wheel['re_ts']), bpod_wheel['re_pos'] + dy) - axes[0].plot(fpga_wheel['re_ts'], fpga_wheel['re_pos']) - axes[0].title.set_text('FPGA') - axes[1].plot(bpod2fpga(bpod_wheel['re_ts']), bpod_wheel['re_pos'] + dy) - axes[1].title.set_text('Bpod') - - return alfio.dataframe({**fpga_trials, **alf_trials, **qct}) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index ed90efb61..03eb7079c 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -187,62 +187,6 @@ def _sync_to_alf(raw_ephys_apfile, output_path=None, save=False, parts=''): return Bunch(sync) -def _assign_events_bpod(bpod_t, bpod_polarities, ignore_first_valve=True): - """ - From detected fronts on the bpod sync traces, outputs the synchronisation events - related to trial start and valve opening - :param bpod_t: numpy vector containing times of fronts - :param bpod_fronts: numpy vector containing polarity of fronts (1 rise, -1 fall) - :param ignore_first_valve (True): removes detected valve events at indices le 2 - :return: numpy arrays of times t_trial_start, t_valve_open and t_iti_in - - TODO Remove function (now using FpgaTrials._assign_events) - """ - TRIAL_START_TTL_LEN = 2.33e-4 # the TTL length is 0.1ms but this has proven to drift on - # some bpods and this is the highest possible value that discriminates trial start from valve - ITI_TTL_LEN = 0.4 - # make sure that there are no 2 consecutive fall or consecutive rise events - assert np.all(np.abs(np.diff(bpod_polarities)) == 2) - if bpod_polarities[0] == -1: - bpod_t = np.delete(bpod_t, 0) - # take only even time differences: ie. from rising to falling fronts - dt = np.diff(bpod_t)[::2] - # detect start trials event assuming length is 0.23 ms except the first trial - i_trial_start = np.r_[0, np.where(dt <= TRIAL_START_TTL_LEN)[0] * 2] - t_trial_start = bpod_t[i_trial_start] - # the last trial is a dud and should be removed - t_trial_start = t_trial_start[:-1] - # valve open events are between 50ms to 300 ms - i_valve_open = np.where(np.logical_and(dt > TRIAL_START_TTL_LEN, - dt < ITI_TTL_LEN))[0] * 2 - if ignore_first_valve: - i_valve_open = np.delete(i_valve_open, np.where(i_valve_open < 2)) - t_valve_open = bpod_t[i_valve_open] - # ITI events are above 400 ms - i_iti_in = np.where(dt > ITI_TTL_LEN)[0] * 2 - i_iti_in = np.delete(i_iti_in, np.where(i_valve_open < 2)) - t_iti_in = bpod_t[i_iti_in] - ## some debug plots when needed - # import matplotlib.pyplot as plt - # import ibllib.plots as plots - # events = {'id': np.zeros(bpod_t.shape), 't': bpod_t, 'p': bpod_polarities} - # events['id'][i_trial_start] = 1 - # events['id'][i_valve_open] = 2 - # events['id'][i_iti_in] = 3 - # i_abnormal = np.where(np.diff(events['id'][bpod_polarities != -1]) == 0) - # t_abnormal = events['t'][bpod_polarities != -1][i_abnormal] - # assert np.all(events != 0) - # plt.figure() - # plots.squares(bpod_t, bpod_polarities, label='raw fronts') - # plots.vertical_lines(t_trial_start, ymin=-0.2, ymax=1.1, linewidth=0.5, label='trial start') - # plots.vertical_lines(t_valve_open, ymin=-0.2, ymax=1.1, linewidth=0.5, label='valve open') - # plots.vertical_lines(t_iti_in, ymin=-0.2, ymax=1.1, linewidth=0.5, label='iti_in') - # plt.plot(t_abnormal, t_abnormal * 0 + .5, 'k*') - # plt.legend() - - return t_trial_start, t_valve_open, t_iti_in - - def _rotary_encoder_positions_from_fronts(ta, pa, tb, pb, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4'): """ Extracts the rotary encoder absolute position as function of time from fronts detected @@ -287,42 +231,6 @@ def _rotary_encoder_positions_from_fronts(ta, pa, tb, pb, ticks=WHEEL_TICKS, rad return t, p -def _assign_events_audio(audio_t, audio_polarities, return_indices=False, display=False): - """ - From detected fronts on the audio sync traces, outputs the synchronisation events - related to tone in - - :param audio_t: numpy vector containing times of fronts - :param audio_fronts: numpy vector containing polarity of fronts (1 rise, -1 fall) - :param return_indices (False): returns indices of tones - :param display (False): for debug mode, displays the raw fronts overlaid with detections - :return: numpy arrays t_ready_tone_in, t_error_tone_in - :return: numpy arrays ind_ready_tone_in, ind_error_tone_in if return_indices=True - - TODO Remove function (now using FpgaTrials._assign_events) - """ - # make sure that there are no 2 consecutive fall or consecutive rise events - assert np.all(np.abs(np.diff(audio_polarities)) == 2) - # take only even time differences: ie. from rising to falling fronts - dt = np.diff(audio_t) - # detect ready tone by length below 110 ms - i_ready_tone_in = np.where(np.logical_and(dt <= 0.11, audio_polarities[:-1] == 1))[0] - t_ready_tone_in = audio_t[i_ready_tone_in] - # error tones are events lasting from 400ms to 1200ms - i_error_tone_in = np.where(np.logical_and(np.logical_and(0.4 < dt, dt < 1.2), audio_polarities[:-1] == 1))[0] - t_error_tone_in = audio_t[i_error_tone_in] - if display: # pragma: no cover - from ibllib.plots import squares, vertical_lines - squares(audio_t, audio_polarities, yrange=[-1, 1],) - vertical_lines(t_ready_tone_in, ymin=-.8, ymax=.8) - vertical_lines(t_error_tone_in, ymin=-.8, ymax=.8) - - if return_indices: - return t_ready_tone_in, t_error_tone_in, i_ready_tone_in, i_error_tone_in - else: - return t_ready_tone_in, t_error_tone_in - - def _assign_events_to_trial(t_trial_start, t_event, take='last'): """ Assign events to a trial given trial start times and event times. @@ -512,126 +420,6 @@ def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None): return re_ts, re_pos -def extract_behaviour_sync(sync, chmap, display=False, bpod_trials=None, tmin=None, tmax=None): - """ - Extract task related event times from the sync. - - Parameters - ---------- - sync : dict - 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' - chmap : dict - Map of channel names and their corresponding index. Default to constant. - display : bool, matplotlib.pyplot.Axes - Show the full session sync pulses display - bpod_trials : dict - The same trial events as recorded through Bpod. Assumed to contain an 'intervals_bpod' key. - tmin : float - The minimum time from which to extract the sync pulses. - tmax : float - The maximum time up to which we extract the sync pulses. - - Returns - ------- - dict - A map of trial event timestamps. - - TODO Remove this function (now using FpgaTrials.extract_behaviour_sync) - """ - bpod = get_sync_fronts(sync, chmap['bpod'], tmin=tmin, tmax=tmax) - if bpod.times.size == 0: - raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. ' - 'Check channel maps.') - frame2ttl = get_sync_fronts(sync, chmap['frame2ttl'], tmin=tmin, tmax=tmax) - frame2ttl = _clean_frame2ttl(frame2ttl) - audio = get_sync_fronts(sync, chmap['audio'], tmin=tmin, tmax=tmax) - audio = _clean_audio(audio) - # extract events from the fronts for each trace - t_trial_start, t_valve_open, t_iti_in = _assign_events_bpod(bpod['times'], bpod['polarities']) - if not bpod_trials: - raise ValueError('No Bpod trials to align') - intervals_bpod = bpod_trials['intervals'] - # If there are no detected trial start times or more than double the trial end pulses, - # the trial start pulses may be too small to be detected, in which case, sync using the ini_in - if t_trial_start.size == 0 or (t_trial_start.size / t_iti_in.size) < .5: - _logger.info('Attempting to align on ITI in') - assert t_iti_in.size > 0, 'no detected ITI in TTLs on the DAQ to align' - bpod_end = bpod_trials['itiIn_times'] - fcn, drift = neurodsp.utils.sync_timestamps(bpod_end, t_iti_in) - # if it's drifting too much - if drift > 200 and bpod_end.size != t_iti_in.size: - raise err.SyncBpodFpgaException('sync cluster f*ck') - t_trial_start = fcn(intervals_bpod[:, 0]) - else: - # one issue is that sometimes bpod pulses may not have been detected, in this case - # perform the sync bpod/FPGA, and add the start that have not been detected - _logger.info('Attempting to align on trial start') - bpod_start = intervals_bpod[:, 0] - fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps( - bpod_start, t_trial_start, return_indices=True) - # if it's drifting too much - if drift > 200 and bpod_start.size != t_trial_start.size: - raise err.SyncBpodFpgaException('sync cluster f*ck') - missing_bpod = fcn(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) - t_trial_start = np.sort(np.r_[t_trial_start, missing_bpod]) - - t_ready_tone_in, t_error_tone_in = _assign_events_audio(audio['times'], audio['polarities']) - trials = Bunch({ - 'goCue_times': _assign_events_to_trial(t_trial_start, t_ready_tone_in, take='first'), - 'errorCue_times': _assign_events_to_trial(t_trial_start, t_error_tone_in), - 'valveOpen_times': _assign_events_to_trial(t_trial_start, t_valve_open), - 'stimFreeze_times': _assign_events_to_trial(t_trial_start, frame2ttl['times'], take=-2), - 'stimOn_times': _assign_events_to_trial(t_trial_start, frame2ttl['times'], take='first'), - 'stimOff_times': _assign_events_to_trial(t_trial_start, frame2ttl['times']), - 'itiIn_times': _assign_events_to_trial(t_trial_start, t_iti_in) - }) - # feedback times are valve open on good trials and error tone in on error trials - trials['feedback_times'] = np.copy(trials['valveOpen_times']) - ind_err = np.isnan(trials['valveOpen_times']) - trials['feedback_times'][ind_err] = trials['errorCue_times'][ind_err] - trials['intervals'] = np.c_[t_trial_start, trials['itiIn_times']] - - if display: # pragma: no cover - width = 0.5 - ymax = 5 - if isinstance(display, bool): - plt.figure("Ephys FPGA Sync") - ax = plt.gca() - else: - ax = display - r0 = get_sync_fronts(sync, chmap['rotary_encoder_0'], tmin=tmin, tmax=tmax) - plots.squares(bpod['times'], bpod['polarities'] * 0.4 + 1, ax=ax, color='k') - plots.squares(frame2ttl['times'], frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') - plots.squares(audio['times'], audio['polarities'] * 0.4 + 3, ax=ax, color='k') - plots.squares(r0['times'], r0['polarities'] * 0.4 + 4, ax=ax, color='k') - plots.vertical_lines(t_ready_tone_in, ymin=0, ymax=ymax, - ax=ax, label='goCue_times', color='b', linewidth=width) - plots.vertical_lines(t_trial_start, ymin=0, ymax=ymax, - ax=ax, label='start_trial', color='m', linewidth=width) - plots.vertical_lines(t_error_tone_in, ymin=0, ymax=ymax, - ax=ax, label='error tone', color='r', linewidth=width) - plots.vertical_lines(t_valve_open, ymin=0, ymax=ymax, - ax=ax, label='valveOpen_times', color='g', linewidth=width) - plots.vertical_lines(trials['stimFreeze_times'], ymin=0, ymax=ymax, - ax=ax, label='stimFreeze_times', color='y', linewidth=width) - plots.vertical_lines(trials['stimOff_times'], ymin=0, ymax=ymax, - ax=ax, label='stim off', color='c', linewidth=width) - plots.vertical_lines(trials['stimOn_times'], ymin=0, ymax=ymax, - ax=ax, label='stimOn_times', color='tab:orange', linewidth=width) - c = get_sync_fronts(sync, chmap['left_camera'], tmin=tmin, tmax=tmax) - plots.squares(c['times'], c['polarities'] * 0.4 + 5, ax=ax, color='k') - c = get_sync_fronts(sync, chmap['right_camera'], tmin=tmin, tmax=tmax) - plots.squares(c['times'], c['polarities'] * 0.4 + 6, ax=ax, color='k') - c = get_sync_fronts(sync, chmap['body_camera'], tmin=tmin, tmax=tmax) - plots.squares(c['times'], c['polarities'] * 0.4 + 7, ax=ax, color='k') - ax.legend() - ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio', 're_0', '']) - ax.set_yticks([0, 1, 2, 3, 4, 5]) - ax.set_ylim([0, 5]) - - return trials, frame2ttl, audio, bpod - - def extract_sync(session_path, overwrite=False, ephys_files=None, namespace='spikeglx'): """ Reads ephys binary file (s) and extract sync within the binary file folder diff --git a/ibllib/tests/extractors/test_ephys_fpga.py b/ibllib/tests/extractors/test_ephys_fpga.py index f5bf85491..a5d4ef254 100644 --- a/ibllib/tests/extractors/test_ephys_fpga.py +++ b/ibllib/tests/extractors/test_ephys_fpga.py @@ -2,12 +2,15 @@ import unittest import tempfile from pathlib import Path +import pickle import numpy as np -import numpy.testing +import spikeglx from ibllib.io.extractors import ephys_fpga -import spikeglx +from ibllib.io.extractors.training_wheel import extract_first_movement_times, infer_wheel_units +from ibllib.io.extractors.training_wheel import extract_wheel_moves +import brainbox.behavior.wheel as wh class TestsFolderStructure(unittest.TestCase): @@ -71,7 +74,7 @@ def sync_gen(self, fn, ns, nc, sync_depth): self.assertIn('SGLX sync found', log.output[0]) -class TestIblChannelMaps(unittest.TestCase): +class TestIBLChannelMaps(unittest.TestCase): def setUp(self): self.workdir = Path(__file__).parents[1] / 'fixtures' @@ -85,7 +88,7 @@ def test_ibl_sync_maps(self): self.assertEqual(s, ephys_fpga.CHMAPS['3B']['ap']) -class TestEphysFPGA_TTLsExtraction(unittest.TestCase): +class TestTTLsExtraction(unittest.TestCase): def test_audio_ttl_wiring_camera(self): """ @@ -197,6 +200,64 @@ def _test_audio(audio, audio_intervals): _test_audio(audio, audio_intervals) np.testing.assert_array_equal(_audio['times'], audio['times'][1:]) + def test_frame2ttl_flickers(self): + """ + Frame2ttl can flicker abnormally. One way to detect this is to remove consecutive polarity + switches under a given threshold + """ + DISPLAY = False # for debug purposes + F2TTL_THRESH = 0.01 + diff = F2TTL_THRESH * np.array([0.5, 10]) + + # flicker ends with a polarity switch - downgoing pulse is removed + t = np.r_[0, np.cumsum(diff[np.array([1, 1, 0, 0, 1])])] + 1 + frame2ttl = {'times': t, 'polarities': np.mod(np.arange(t.size) + 1, 2) * 2 - 1} + expected = {'times': np.array([1., 1.1, 1.2, 1.31]), + 'polarities': np.array([1, -1, 1, -1])} + frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY, threshold=F2TTL_THRESH) + assert all(np.all(frame2ttl_[k] == expected[k]) for k in frame2ttl_) + + # stand-alone flicker + t = np.r_[0, np.cumsum(diff[np.array([1, 1, 0, 0, 0, 1])])] + 1 + frame2ttl = {'times': t, 'polarities': np.mod(np.arange(t.size) + 1, 2) * 2 - 1} + expected = {'times': np.array([1., 1.1, 1.2, 1.215, 1.315]), + 'polarities': np.array([1, -1, 1, -1, 1])} + frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY) + assert all(np.all(frame2ttl_[k] == expected[k]) for k in frame2ttl_) + + def test_bpod_trace_extraction(self): + """Test FpgaTrials.get_bpod_event_times method.""" + expected = { + 'trial_start': np.array([0, 4, 8, 12, 14, 16, 20, 24, 26, 28, 32, 34, 38, 42, 46, 48, 52, 56]), + 'valve_open': np.array([2, 6, 10, 18, 22, 30, 36, 40, 44, 50, 54]) + } + bpod_fronts_ = np.array([1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., + -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., + 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., + -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., + 1., -1., 1., -1., 1., -1.]) + + bpod_times_ = np.array([6.75033333, 109.7648, 117.12136667, 117.27136667, + 118.51416667, 118.51426667, 122.3873, 122.5373, + 123.7964, 123.7965, 127.82903333, 127.97903333, + 129.24503333, 129.24513333, 132.97976667, 132.97986667, + 136.8624, 136.8625, 140.56083333, 140.71083333, + 141.95523333, 141.95533333, 143.55326667, 143.70326667, + 144.93636667, 144.93646667, 149.5042, 149.5043, + 153.08273333, 153.08283333, 155.29713333, 155.44713333, + 156.70316667, 156.70326667, 164.0096, 164.0097, + 164.9186, 165.0686, 166.30633333, 166.30643333, + 167.91133333, 168.06133333, 169.28373333, 169.28386667, + 171.39736667, 171.54736667, 172.77786667, 172.77796667, + 176.7828, 176.7829, 178.0305, 178.1805, + 179.41063333, 179.41073333, 181.70343333, 181.85343333, + 183.12896667, 183.12906667]) + extractor = ephys_fpga.FpgaTrials('subject/2023-01-01/000') # placeholder session path unused + sync = {'times': bpod_times_, 'polarities': bpod_fronts_, 'channels': np.zeros_like(bpod_times_, dtype=int)} + _, bpod_intervals = extractor.get_bpod_event_times(sync, {'bpod': 0}) + for k in expected: + np.testing.assert_array_equal(bpod_intervals[k][:, 0], bpod_times_[expected[k]]) + def test_ttl_bpod_gaelle_writes_protocols_but_guido_doesnt_read_them(self): bpod_t = np.array([5.423290950005423, 6.397993470006398, 6.468919710006469, 7.497916800007498, 7.997933460007998, 8.599239990008599, @@ -221,30 +282,247 @@ def test_ttl_bpod_gaelle_writes_protocols_but_guido_doesnt_read_them(self): for event, intervals in bpod_intervals.items(): np.testing.assert_array_equal(intervals, bpod_intervals_[event], err_msg=event) - def test_frame2ttl_flickers(self): + def test_align_to_trial(self): + """Test ephys_fpga._assign_events_to_trial function.""" + # simple test with one missing at the end + t_trial_start = np.arange(0, 5) * 10 + t_event = np.arange(0, 5) * 10 + 2 + t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event) + self.assertTrue(np.allclose(t_event_nans, t_event, equal_nan=True, atol=0, rtol=0)) + + # test with missing values + t_trial_start = np.array([109, 118, 123, 129, 132, 136, 141, 144, 149, 153]) + t_event = np.array([122, 133, 140, 143, 146, 150, 154]) + t_event_out_ = np.array([np.nan, 122, np.nan, np.nan, 133, 140, 143, 146, 150, 154]) + t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event) + self.assertTrue(np.allclose(t_event_out_, t_event_nans, equal_nan=True, atol=0, rtol=0)) + + # test with events before initial start trial + t_trial_start = np.arange(0, 5) * 10 + t_event = np.arange(0, 5) * 10 - 2 + t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event) + desired_out = np.array([8., 18., 28., 38., np.nan]) + self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) + + # test with several events per trial, missing events and events before + t_trial_start = np.array([0, 10, 20, 30, 40]) + t_event = np.array([-1, 2, 4, 12, 35, 42]) + t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event) + desired_out = np.array([4, 12., np.nan, 35, 42]) + self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) + + # same test above but this time take the first index instead of last + t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event, take='first') + desired_out = np.array([2, 12., np.nan, 35, 42]) + self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) + + # take second to last + t_trial_start = np.array([0, 10, 20, 30, 40]) + t_event = np.array([2, 4, 12, 13, 14, 21, 32, 33, 34, 35, 42]) + t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event, take=-2) + desired_out = np.array([2, 13, np.nan, 34, np.nan]) + self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) + t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event, take=1) + desired_out = np.array([4, 13, np.nan, 33, np.nan]) + self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) + + # test errors + self.assertRaises(ValueError, ephys_fpga._assign_events_to_trial, np.array([0., 2., 1.]), t_event) + self.assertRaises(ValueError, ephys_fpga._assign_events_to_trial, t_trial_start, np.array([0., 2., 1.])) + + +class TestWheelExtraction(unittest.TestCase): + + def setUp(self) -> None: + self.ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + self.pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + self.tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) + self.pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + + def test_x1_decoding(self): + p_ = np.array([1, 2, 1, 0]) + t_ = np.array([2, 6, 11, 15]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts(self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x1') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(p == p_)) + + def test_x4_decoding(self): + p_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 0]) / 4 + t_ = np.array([2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts(self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x4') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(np.isclose(p, p_))) + + def test_x2_decoding(self): + p_ = np.array([1, 2, 3, 4, 3, 2, 1, 0]) / 2 + t_ = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts(self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x2') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(p == p_)) + + def test_wheel_trace_from_sync(self): + """Test ephys_fpga._rotary_encoder_positions_from_fronts function.""" + pos_ = - np.array([-1, 0, -1, -2, -1, -2]) * (np.pi / ephys_fpga.WHEEL_TICKS) + ta = np.array([1, 2, 3, 4, 5, 6]) + tb = np.array([0.5, 3.2, 3.3, 3.4, 5.25, 5.5]) + pa = (np.mod(np.arange(6), 2) - 0.5) * 2 + pb = (np.mod(np.arange(6) + 1, 2) - .5) * 2 + t, pos = ephys_fpga._rotary_encoder_positions_from_fronts(ta, pa, tb, pb, coding='x2') + self.assertTrue(np.all(np.isclose(pos_, pos))) + + pos_ = - np.array([-1, 0, -1, 0, -1, -2]) * (np.pi / ephys_fpga.WHEEL_TICKS) + tb = np.array([0.5, 3.2, 3.4, 5.25]) + pb = (np.mod(np.arange(4) + 1, 2) - .5) * 2 + t, pos = ephys_fpga._rotary_encoder_positions_from_fronts(ta, pa, tb, pb, coding='x2') + self.assertTrue(np.all(np.isclose(pos_, pos))) + + +class TestExtractedWheelUnits(unittest.TestCase): + """Tests the infer_wheel_units function.""" + + wheel_radius_cm = 3.1 + + def setUp(self) -> None: """ - Frame2ttl can flicker abnormally. One way to detect this is to remove consecutive polarity - switches under a given threshold + Create the wheel position data for testing: the positions attribute holds a dictionary of + units, each holding a dictionary of encoding types to test, e.g. + + positions = { + 'rad': { + 'X1': ..., + 'X2': ..., + 'X4': ... + }, + 'cm': { + 'X1': ..., + 'X2': ..., + 'X4': ... + } + } + :return: """ - DISPLAY = False # for debug purposes - F2TTL_THRESH = 0.01 - diff = F2TTL_THRESH * np.array([0.5, 10]) - # flicker ends with a polarity switch - downgoing pulse is removed - t = np.r_[0, np.cumsum(diff[np.array([1, 1, 0, 0, 1])])] + 1 - frame2ttl = {'times': t, 'polarities': np.mod(np.arange(t.size) + 1, 2) * 2 - 1} - expected = {'times': np.array([1., 1.1, 1.2, 1.31]), - 'polarities': np.array([1, -1, 1, -1])} - frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY, threshold=F2TTL_THRESH) - assert all([np.all(frame2ttl_[k] == expected[k]) for k in frame2ttl_]) + def x(unit, enc=int(1), wheel_radius=self.wheel_radius_cm): + radius = 1 if unit == 'rad' else wheel_radius + return 1 / ephys_fpga.WHEEL_TICKS * np.pi * 2 * radius / enc - # stand-alone flicker - t = np.r_[0, np.cumsum(diff[np.array([1, 1, 0, 0, 0, 1])])] + 1 - frame2ttl = {'times': t, 'polarities': np.mod(np.arange(t.size) + 1, 2) * 2 - 1} - expected = {'times': np.array([1., 1.1, 1.2, 1.215, 1.315]), - 'polarities': np.array([1, -1, 1, -1, 1])} - frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY) - assert all([np.all(frame2ttl_[k] == expected[k]) for k in frame2ttl_]) + # A pseudo-random sequence of integrated fronts + seq = np.array([-1, 0, 1, 2, 1, 2, 3, 4, 3, 2, 1, 0, -1, -2, 1, -2]) + encs = (1, 2, 4) # Encoding types to test + units = ('rad', 'cm') # Units to test + self.positions = {unit: {f'X{e}': x(unit, e) * seq for e in encs} for unit in units} + + def test_extract_wheel_moves(self): + for unit in self.positions.keys(): + for encoding, pos in self.positions[unit].items(): + result = infer_wheel_units(pos) + self.assertEqual(unit, result[0], f'failed to determine units for {encoding}') + expected = int(ephys_fpga.WHEEL_TICKS * int(encoding[1])) + self.assertEqual(expected, result[1], f'failed to determine number of ticks for {encoding} in {unit}') + self.assertEqual(encoding, result[2], f'failed to determine encoding in {unit}') + + +class TestWheelMovesExtraction(unittest.TestCase): + + def setUp(self) -> None: + """ + Test data is in the form ((inputs), (outputs)) where inputs is a tuple containing a + numpy array of timestamps and one of positions; outputs is a tuple of outputs from + the functions. For details, see help on TestWheel.setUp method in module + brainbox.tests.test_behavior + """ + pickle_file = Path(__file__).parents[3].joinpath('brainbox', 'tests', 'fixtures', 'wheel_test.p') + if not pickle_file.exists(): + self.test_data = None + else: + with open(pickle_file, 'rb') as f: + self.test_data = pickle.load(f) + + # Some trial times for trial_data[1] + self.trials = {'goCue_times': np.array([162.5, 105.6, 55]), 'feedback_times': np.array([164.3, 108.3, 56])} + + def test_extract_wheel_moves(self): + test_data = self.test_data[1] + # Wrangle data into expected form + re_ts = test_data[0][0] + re_pos = test_data[0][1] + + logname = 'ibllib.io.extractors.training_wheel' + with self.assertLogs(logname, level='INFO') as cm: + wheel_moves = extract_wheel_moves(re_ts, re_pos) + self.assertEqual([f'INFO:{logname}:Wheel in cm units using X2 encoding'], cm.output) + + n = 56 # expected number of movements + self.assertTupleEqual(wheel_moves['intervals'].shape, (n, 2), 'failed to return the correct number of intervals') + self.assertEqual(wheel_moves['peakAmplitude'].size, n) + self.assertEqual(wheel_moves['peakVelocity_times'].size, n) + + # Check the first 3 intervals + ints = np.array([[24.78462599, 25.22562599], [29.58762599, 31.15062599], [31.64262599, 31.81662599]]) + actual = wheel_moves['intervals'][:3, ] + self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') + + # Check amplitudes + actual = wheel_moves['peakAmplitude'][-3:] + expected = [0.50255486, -1.70103154, 1.00740789] + self.assertIsNone(np.testing.assert_allclose(actual, expected), 'unexpected amplitudes') + + # Check peak velocities + actual = wheel_moves['peakVelocity_times'][-3:] + expected = [175.13662599, 176.65762599, 178.57262599] + self.assertIsNone(np.testing.assert_allclose(actual, expected), 'peak times') + + # Test extraction in rad + re_pos = wh.cm_to_rad(re_pos) + with self.assertLogs(logname, level='INFO') as cm: + wheel_moves = ephys_fpga.extract_wheel_moves(re_ts, re_pos) + self.assertEqual([f'INFO:{logname}:Wheel in rad units using X2 encoding'], cm.output) + + # Check the first 3 intervals. As position thresholds are adjusted by units and + # encoding, we should expect the intervals to be identical to above + actual = wheel_moves['intervals'][:3, ] + self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') + + def test_movement_log(self): + """ + Integration test for inferring the units and decoding type for wheel data input for + extract_wheel_moves. Only expected to work for the default wheel diameter. + """ + ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) + pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + logname = 'ibllib.io.extractors.training_wheel' + + for unit in ['cm', 'rad']: + for i in (1, 2, 4): + encoding = 'X' + str(i) + r = 3.1 if unit == 'cm' else 1 + # print(encoding, unit) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + ta, pa, tb, pb, ticks=1024, coding=encoding.lower(), radius=r) + expected = f'INFO:{logname}:Wheel in {unit} units using {encoding} encoding' + with self.assertLogs(logname, level='INFO') as cm: + ephys_fpga.extract_wheel_moves(t, p) + self.assertEqual([expected], cm.output) + + def test_extract_first_movement_times(self): + test_data = self.test_data[1] + wheel_moves = ephys_fpga.extract_wheel_moves(test_data[0][0], test_data[0][1]) + first, is_final, ind = extract_first_movement_times(wheel_moves, self.trials) + np.testing.assert_allclose(first, [162.48462599, 105.62562599, np.nan]) + np.testing.assert_array_equal(is_final, [False, True, False]) + np.testing.assert_array_equal(ind, [46, 18]) + + +class TestFpgaTrials(unittest.TestCase): + """Test FpgaTrials class.""" + + def test_time_fields(self): + """Test for FpgaTrials._time_fields static method.""" + expected = ('intervals', 'fooBar_times_bpod', 'spike_times', 'baz_timestamps') + fields = ephys_fpga.FpgaTrials._time_fields(expected + ('position', 'timebase', 'fooBaz')) + self.assertCountEqual(expected, fields) if __name__ == '__main__': diff --git a/ibllib/tests/extractors/test_ephys_trials.py b/ibllib/tests/extractors/test_ephys_trials.py deleted file mode 100644 index 7d77079af..000000000 --- a/ibllib/tests/extractors/test_ephys_trials.py +++ /dev/null @@ -1,330 +0,0 @@ -import unittest -from pathlib import Path -import pickle - -import numpy as np - -from ibllib.io.extractors import ephys_fpga, biased_trials -import ibllib.io.raw_data_loaders as raw -from ibllib.io.extractors.training_wheel import extract_first_movement_times, infer_wheel_units -from ibllib.io.extractors.training_wheel import extract_wheel_moves -import brainbox.behavior.wheel as wh - - -class TestEphysSyncExtraction(unittest.TestCase): - - def test_bpod_trace_extraction(self): - """Test ephys_fpga._assign_events_bpod function. - - TODO Remove this test and corresponding function. - """ - t_valve_open_ = np.array([117.12136667, 122.3873, 127.82903333, 140.56083333, - 143.55326667, 155.29713333, 164.9186, 167.91133333, - 171.39736667, 178.0305, 181.70343333]) - - t_trial_start_ = np.array([109.7647, 118.51416667, 123.7964, 129.24503333, - 132.97976667, 136.8624, 141.95523333, 144.93636667, - 149.5042, 153.08273333, 156.70316667, 164.0096, - 166.30633333, 169.28373333, 172.77786667, 176.7828, - 179.41063333]) - t_trial_start_[0] = 6.75033333 # rising front for first trial instead of falling - bpod_fronts_ = np.array([1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., - -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., - 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., - -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., - 1., -1., 1., -1., 1., -1.]) - - bpod_times_ = np.array([6.75033333, 109.7648, 117.12136667, 117.27136667, - 118.51416667, 118.51426667, 122.3873, 122.5373, - 123.7964, 123.7965, 127.82903333, 127.97903333, - 129.24503333, 129.24513333, 132.97976667, 132.97986667, - 136.8624, 136.8625, 140.56083333, 140.71083333, - 141.95523333, 141.95533333, 143.55326667, 143.70326667, - 144.93636667, 144.93646667, 149.5042, 149.5043, - 153.08273333, 153.08283333, 155.29713333, 155.44713333, - 156.70316667, 156.70326667, 164.0096, 164.0097, - 164.9186, 165.0686, 166.30633333, 166.30643333, - 167.91133333, 168.06133333, 169.28373333, 169.28386667, - 171.39736667, 171.54736667, 172.77786667, 172.77796667, - 176.7828, 176.7829, 178.0305, 178.1805, - 179.41063333, 179.41073333, 181.70343333, 181.85343333, - 183.12896667, 183.12906667]) - - t_trial_start, t_valve_open, _ = ephys_fpga._assign_events_bpod(bpod_times_, - bpod_fronts_) - self.assertTrue(np.all(np.isclose(t_trial_start, t_trial_start_))) - self.assertTrue(np.all(np.isclose(t_valve_open, t_valve_open_))) - - def test_align_to_trial(self): - """Test ephys_fpga._assign_events_to_trial function.""" - # simple test with one missing at the end - t_trial_start = np.arange(0, 5) * 10 - t_event = np.arange(0, 5) * 10 + 2 - t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event) - self.assertTrue(np.allclose(t_event_nans, t_event, equal_nan=True, atol=0, rtol=0)) - - # test with missing values - t_trial_start = np.array([109, 118, 123, 129, 132, 136, 141, 144, 149, 153]) - t_event = np.array([122, 133, 140, 143, 146, 150, 154]) - t_event_out_ = np.array([np.nan, 122, np.nan, np.nan, 133, 140, 143, 146, 150, 154]) - t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event) - self.assertTrue(np.allclose(t_event_out_, t_event_nans, equal_nan=True, atol=0, rtol=0)) - - # test with events before initial start trial - t_trial_start = np.arange(0, 5) * 10 - t_event = np.arange(0, 5) * 10 - 2 - t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event) - desired_out = np.array([8., 18., 28., 38., np.nan]) - self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) - - # test with several events per trial, missing events and events before - t_trial_start = np.array([0, 10, 20, 30, 40]) - t_event = np.array([-1, 2, 4, 12, 35, 42]) - t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event) - desired_out = np.array([4, 12., np.nan, 35, 42]) - self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) - - # same test above but this time take the first index instead of last - t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event, take='first') - desired_out = np.array([2, 12., np.nan, 35, 42]) - self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) - - # take second to last - t_trial_start = np.array([0, 10, 20, 30, 40]) - t_event = np.array([2, 4, 12, 13, 14, 21, 32, 33, 34, 35, 42]) - t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event, take=-2) - desired_out = np.array([2, 13, np.nan, 34, np.nan]) - self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) - t_event_nans = ephys_fpga._assign_events_to_trial(t_trial_start, t_event, take=1) - desired_out = np.array([4, 13, np.nan, 33, np.nan]) - self.assertTrue(np.allclose(desired_out, t_event_nans, equal_nan=True, atol=0, rtol=0)) - - # test errors - self.assertRaises(ValueError, ephys_fpga._assign_events_to_trial, np.array([0., 2., 1.]), t_event) - self.assertRaises(ValueError, ephys_fpga._assign_events_to_trial, t_trial_start, np.array([0., 2., 1.])) - - def test_wheel_trace_from_sync(self): - """Test ephys_fpga._rotary_encoder_positions_from_fronts function.""" - pos_ = - np.array([-1, 0, -1, -2, -1, -2]) * (np.pi / ephys_fpga.WHEEL_TICKS) - ta = np.array([1, 2, 3, 4, 5, 6]) - tb = np.array([0.5, 3.2, 3.3, 3.4, 5.25, 5.5]) - pa = (np.mod(np.arange(6), 2) - 0.5) * 2 - pb = (np.mod(np.arange(6) + 1, 2) - .5) * 2 - t, pos = ephys_fpga._rotary_encoder_positions_from_fronts(ta, pa, tb, pb, coding='x2') - self.assertTrue(np.all(np.isclose(pos_, pos))) - - pos_ = - np.array([-1, 0, -1, 0, -1, -2]) * (np.pi / ephys_fpga.WHEEL_TICKS) - tb = np.array([0.5, 3.2, 3.4, 5.25]) - pb = (np.mod(np.arange(4) + 1, 2) - .5) * 2 - t, pos = ephys_fpga._rotary_encoder_positions_from_fronts(ta, pa, tb, pb, coding='x2') - self.assertTrue(np.all(np.isclose(pos_, pos))) - - def test_time_fields(self): - """Test for FpgaTrials._time_fields static method.""" - expected = ('intervals', 'fooBar_times_bpod', 'spike_times', 'baz_timestamps') - fields = ephys_fpga.FpgaTrials._time_fields(expected + ('position', 'timebase', 'fooBaz')) - self.assertCountEqual(expected, fields) - - -class TestEphysBehaviorExtraction(unittest.TestCase): - def setUp(self): - self.session_path = Path(__file__).parent.joinpath('data', 'session_ephys') - - def test_get_probabilityLeft(self): - data = raw.load_data(self.session_path) - settings = raw.load_settings(self.session_path) - *_, pLeft0, _ = biased_trials.ProbaContrasts( - self.session_path).extract(bpod_trials=data, settings=settings)[0] - self.assertTrue(len(pLeft0) == len(data)) - # Test if only generative prob values in data - self.assertTrue(all([x in [0.2, 0.5, 0.8] for x in np.unique(pLeft0)])) - # Test if settings file has empty LEN_DATA result is same - settings.update({"LEN_BLOCKS": None}) - *_, pLeft1, _ = biased_trials.ProbaContrasts( - self.session_path).extract(bpod_trials=data, settings=settings)[0] - self.assertTrue(all(pLeft0 == pLeft1)) - # Test if only generative prob values in data - self.assertTrue(all([x in [0.2, 0.5, 0.8] for x in np.unique(pLeft1)])) - - -class TestWheelExtraction(unittest.TestCase): - - def setUp(self) -> None: - self.ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - self.pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - self.tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) - self.pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - - def test_x1_decoding(self): - p_ = np.array([1, 2, 1, 0]) - t_ = np.array([2, 6, 11, 15]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x1') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(p == p_)) - - def test_x4_decoding(self): - p_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 0]) / 4 - t_ = np.array([2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x4') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(np.isclose(p, p_))) - - def test_x2_decoding(self): - p_ = np.array([1, 2, 3, 4, 3, 2, 1, 0]) / 2 - t_ = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x2') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(p == p_)) - - -class TestExtractedWheelUnits(unittest.TestCase): - """Tests the infer_wheel_units function""" - - wheel_radius_cm = 3.1 - - def setUp(self) -> None: - """ - Create the wheel position data for testing: the positions attribute holds a dictionary of - units, each holding a dictionary of encoding types to test, e.g. - - positions = { - 'rad': { - 'X1': ..., - 'X2': ..., - 'X4': ... - }, - 'cm': { - 'X1': ..., - 'X2': ..., - 'X4': ... - } - } - :return: - """ - def x(unit, enc=int(1), wheel_radius=self.wheel_radius_cm): - radius = 1 if unit == 'rad' else wheel_radius - return 1 / ephys_fpga.WHEEL_TICKS * np.pi * 2 * radius / enc - - # A pseudo-random sequence of integrated fronts - seq = np.array([-1, 0, 1, 2, 1, 2, 3, 4, 3, 2, 1, 0, -1, -2, 1, -2]) - encs = (1, 2, 4) # Encoding types to test - units = ('rad', 'cm') # Units to test - self.positions = {unit: {f'X{e}': x(unit, e) * seq for e in encs} for unit in units} - - def test_extract_wheel_moves(self): - for unit in self.positions.keys(): - for encoding, pos in self.positions[unit].items(): - result = infer_wheel_units(pos) - self.assertEqual(unit, result[0], f'failed to determine units for {encoding}') - expected = int(ephys_fpga.WHEEL_TICKS * int(encoding[1])) - self.assertEqual(expected, result[1], - f'failed to determine number of ticks for {encoding} in {unit}') - self.assertEqual(encoding, result[2], f'failed to determine encoding in {unit}') - - -class TestWheelMovesExtraction(unittest.TestCase): - - def setUp(self) -> None: - """ - Test data is in the form ((inputs), (outputs)) where inputs is a tuple containing a - numpy array of timestamps and one of positions; outputs is a tuple of outputs from - the functions. For details, see help on TestWheel.setUp method in module - brainbox.tests.test_behavior - """ - pickle_file = Path(__file__).parents[3].joinpath( - 'brainbox', 'tests', 'fixtures', 'wheel_test.p') - if not pickle_file.exists(): - self.test_data = None - else: - with open(pickle_file, 'rb') as f: - self.test_data = pickle.load(f) - - # Some trial times for trial_data[1] - self.trials = { - 'goCue_times': np.array([162.5, 105.6, 55]), - 'feedback_times': np.array([164.3, 108.3, 56]) - } - - def test_extract_wheel_moves(self): - test_data = self.test_data[1] - # Wrangle data into expected form - re_ts = test_data[0][0] - re_pos = test_data[0][1] - - logname = 'ibllib.io.extractors.training_wheel' - with self.assertLogs(logname, level='INFO') as cm: - wheel_moves = extract_wheel_moves(re_ts, re_pos) - self.assertEqual([f'INFO:{logname}:Wheel in cm units using X2 encoding'], cm.output) - - n = 56 # expected number of movements - self.assertTupleEqual(wheel_moves['intervals'].shape, (n, 2), - 'failed to return the correct number of intervals') - self.assertEqual(wheel_moves['peakAmplitude'].size, n) - self.assertEqual(wheel_moves['peakVelocity_times'].size, n) - - # Check the first 3 intervals - ints = np.array( - [[24.78462599, 25.22562599], - [29.58762599, 31.15062599], - [31.64262599, 31.81662599]]) - actual = wheel_moves['intervals'][:3, ] - self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') - - # Check amplitudes - actual = wheel_moves['peakAmplitude'][-3:] - expected = [0.50255486, -1.70103154, 1.00740789] - self.assertIsNone(np.testing.assert_allclose(actual, expected), 'unexpected amplitudes') - - # Check peak velocities - actual = wheel_moves['peakVelocity_times'][-3:] - expected = [175.13662599, 176.65762599, 178.57262599] - self.assertIsNone(np.testing.assert_allclose(actual, expected), 'peak times') - - # Test extraction in rad - re_pos = wh.cm_to_rad(re_pos) - with self.assertLogs(logname, level='INFO') as cm: - wheel_moves = ephys_fpga.extract_wheel_moves(re_ts, re_pos) - self.assertEqual([f'INFO:{logname}:Wheel in rad units using X2 encoding'], cm.output) - - # Check the first 3 intervals. As position thresholds are adjusted by units and - # encoding, we should expect the intervals to be identical to above - actual = wheel_moves['intervals'][:3, ] - self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') - - def test_movement_log(self): - """ - Integration test for inferring the units and decoding type for wheel data input for - extract_wheel_moves. Only expected to work for the default wheel diameter. - """ - ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) - pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - logname = 'ibllib.io.extractors.training_wheel' - - for unit in ['cm', 'rad']: - for i in (1, 2, 4): - encoding = 'X' + str(i) - r = 3.1 if unit == 'cm' else 1 - # print(encoding, unit) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - ta, pa, tb, pb, ticks=1024, coding=encoding.lower(), radius=r) - expected = f'INFO:{logname}:Wheel in {unit} units using {encoding} encoding' - with self.assertLogs(logname, level='INFO') as cm: - ephys_fpga.extract_wheel_moves(t, p) - self.assertEqual([expected], cm.output) - - def test_extract_first_movement_times(self): - test_data = self.test_data[1] - wheel_moves = ephys_fpga.extract_wheel_moves(test_data[0][0], test_data[0][1]) - first, is_final, ind = extract_first_movement_times(wheel_moves, self.trials) - np.testing.assert_allclose(first, [162.48462599, 105.62562599, np.nan]) - np.testing.assert_array_equal(is_final, [False, True, False]) - np.testing.assert_array_equal(ind, [46, 18]) - - -if __name__ == '__main__': - unittest.main(exit=False, verbosity=2) diff --git a/ibllib/tests/extractors/test_extractors.py b/ibllib/tests/extractors/test_extractors.py index 7bd58d4f0..437c43c07 100644 --- a/ibllib/tests/extractors/test_extractors.py +++ b/ibllib/tests/extractors/test_extractors.py @@ -1,3 +1,4 @@ +"""Test trials, wheel and camera extractors.""" import functools import shutil import tempfile @@ -47,10 +48,12 @@ def setUp(self): self.biased_lt5 = {'path': self.main_path / 'data' / 'session_biased_lt5'} self.training_ge5 = {'path': self.main_path / 'data' / 'session_training_ge5'} self.biased_ge5 = {'path': self.main_path / 'data' / 'session_biased_ge5'} + self.ephys = {'path': self.main_path / 'data' / 'session_ephys'} self.training_lt5['ntrials'] = len(raw.load_data(self.training_lt5['path'])) self.biased_lt5['ntrials'] = len(raw.load_data(self.biased_lt5['path'])) self.training_ge5['ntrials'] = len(raw.load_data(self.training_ge5['path'])) self.biased_ge5['ntrials'] = len(raw.load_data(self.biased_ge5['path'])) + self.ephys['ntrials'] = len(raw.load_data(self.ephys['path'])) # turn off logging for unit testing as we will purposedly go into warning/error cases self.wheel_ge5_path = self.main_path / 'data' / 'wheel_ge5' self.wheel_lt5_path = self.main_path / 'data' / 'wheel_lt5' @@ -144,6 +147,22 @@ def test_get_probabilityLeft(self): probs.append(0.5) self.assertTrue(sum([x in probs for x in pl]) == len(pl)) + # EPHYS SESSION + data = raw.load_data(self.ephys['path']) + md = raw.load_settings(self.ephys['path']) + *_, pLeft0, _ = biased_trials.ProbaContrasts( + self.ephys['path']).extract(bpod_trials=data, settings=md)[0] + self.assertEqual(len(pLeft0), self.ephys['ntrials'], 'ephys prob left') + # Test if only generative prob values in data + self.assertTrue(all(x in [0.2, 0.5, 0.8] for x in np.unique(pLeft0))) + # Test if settings file has empty LEN_DATA result is same + md.update({'LEN_BLOCKS': None}) + *_, pLeft1, _ = biased_trials.ProbaContrasts( + self.ephys['path']).extract(bpod_trials=data, settings=md)[0] + self.assertTrue(all(pLeft0 == pLeft1)) + # Test if only generative prob values in data + self.assertTrue(all(x in [0.2, 0.5, 0.8] for x in np.unique(pLeft1))) + def test_get_choice(self): # TRAINING SESSIONS choice = training_trials.Choice( @@ -761,5 +780,5 @@ def test_attribute_times(self, display=False): camera.attribute_times(tsa, tsb, injective=False, take='closest') -if __name__ == "__main__": +if __name__ == '__main__': unittest.main(exit=False, verbosity=2) From 276603eb1cb50197415897ea1795a806182ded68 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 12 Dec 2023 17:12:14 +0200 Subject: [PATCH 62/68] Added trainingPhaseChoiceWorld to task protocol extractor map; fix typo in ephysqc --- ibllib/ephys/ephysqc.py | 2 +- ibllib/io/extractors/task_extractor_map.json | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ibllib/ephys/ephysqc.py b/ibllib/ephys/ephysqc.py index 483fe51f9..6b8607ce9 100644 --- a/ibllib/ephys/ephysqc.py +++ b/ibllib/ephys/ephysqc.py @@ -370,7 +370,7 @@ def _single_test(assertion, str_ok, str_ko): try: # note: tried to depend as little as possible on the extraction code but for the valve... extractor = ephys_fpga.FpgaTrials(ses_path) - bpod_intervals = extractor.get_bpod_event_times(sync, sync_map) + _, bpod_intervals = extractor.get_bpod_event_times(rawsync, sync_map) t_valve_open = bpod_intervals['valve_open'][:, 0] res = t_valve_open.size > 1 except AssertionError: diff --git a/ibllib/io/extractors/task_extractor_map.json b/ibllib/io/extractors/task_extractor_map.json index 22d0eebb0..2c3160265 100644 --- a/ibllib/io/extractors/task_extractor_map.json +++ b/ibllib/io/extractors/task_extractor_map.json @@ -1,5 +1,6 @@ {"ephysChoiceWorld": "EphysTrials", "_biasedChoiceWorld": "BiasedTrials", "_habituationChoiceWorld": "HabituationTrials", - "_trainingChoiceWorld": "TrainingTrials" + "_trainingChoiceWorld": "TrainingTrials", + "_trainingPhaseChoiceWorld": "TrainingTrials" } From f61781379653c179850c6f01a0553726c346d4e9 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 12 Dec 2023 17:41:38 +0200 Subject: [PATCH 63/68] remove prepare_experiment --- ibllib/io/session_params.py | 57 ------------------------------------- 1 file changed, 57 deletions(-) diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index 417838b55..300a5d6e3 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -473,60 +473,3 @@ def get_remote_stub_name(session_path, device_id=None): exp_ref = '{date}_{sequence:d}_{subject:s}'.format(**ConversionMixin.path2ref(session_path)) remote_filename = f'{exp_ref}@{device_id}.yaml' return session_path / '_devices' / remote_filename - - -def prepare_experiment(session_path, acquisition_description=None, local=None, remote=None, device_id=None, overwrite=False): - """ - Copy acquisition description yaml to the server and local transfers folder. - - Parameters - ---------- - session_path : str, pathlib.Path, pathlib.PurePath - The RELATIVE session path, e.g. subject/2020-01-01/001. - acquisition_description : dict - The data to write to the experiment.description.yaml file. - local : str, pathlib.Path - The path to the local session folders. - >>> C:\iblrigv8_data\cortexlab\Subjects # noqa - remote : str, pathlib.Path - The path to the remote server session folders. - >>> Y:\Subjects # noqa - device_id : str, optional - A device name, if None the TRANSFER_LABEL parameter is used (defaults to this device's - hostname with a unique numeric ID) - overwrite : bool - If true, overwrite any existing file with the new one, otherwise, update the existing file. - """ - if not acquisition_description: - return - - # Determine if user passed in arg for local/remote subject folder locations or pull in from - # local param file or prompt user if missing data. - if local is None or remote is None or device_id is None: - params = misc.create_basic_transfer_params(local_data_path=local, remote_data_path=remote, TRANSFER_LABEL=device_id) - local, device_id = (params['DATA_FOLDER_PATH'], params['TRANSFER_LABEL']) - # if the user provides False as an argument, it means the intent is to not copy anything, this - # won't be preserved by create_basic_transfer_params by default - remote = False if remote is False else params['REMOTE_DATA_FOLDER_PATH'] - - # This is in the docstring but still, if the session Path is absolute, we need to make it relative - if Path(session_path).is_absolute(): - session_path = Path(*session_path.parts[-3:]) - - # First attempt to copy to server - if remote is not False: - remote_session_path = Path(remote).joinpath(session_path) - remote_device_path = get_remote_stub_name(remote_session_path, device_id=device_id) - previous_description = read_params(remote_device_path) if remote_device_path.exists() and not overwrite else {} - try: - write_yaml(remote_device_path, merge_params(previous_description, acquisition_description)) - _logger.info(f'Written data to remote device at: {remote_device_path}.') - except Exception as ex: - _logger.warning(f'Failed to write data to remote device at: {remote_device_path}. \n {ex}') - - # then create on the local machine - filename = f'_ibl_experiment.description_{device_id}.yaml' - local_device_path = Path(local).joinpath(session_path, filename) - previous_description = read_params(local_device_path) if local_device_path.exists() and not overwrite else {} - write_yaml(local_device_path, merge_params(previous_description, acquisition_description)) - _logger.info(f'Written data to local session at : {local_device_path}.') From ebc9774b1ffc5d9bf5e85b0545663e506ed2bef7 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 13 Dec 2023 13:52:16 +0200 Subject: [PATCH 64/68] Release notes --- release_notes.md | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/release_notes.md b/release_notes.md index a7aab07df..29984fb50 100644 --- a/release_notes.md +++ b/release_notes.md @@ -1,6 +1,25 @@ -## Develop +## Release Notes 2.27 + +### features - Add full video wheel motion alignment code to ibllib.io.extractors.video_motion module - Change FPGA camera extractor to attempt wheel alignment if audio alignment fails +- Flexible FpgaTrials class allows subclassing for changes in hardware and task +- Task QC thresholds depend on sound card +- Extractor classes now return dicts instead of tuple +- Support extraction of habituationChoiceWorld with FPGA +- New IBLGlobusPatcher class allows safe and complete deletion of datasets + +### bugfixes +- Fix numpy version dependent error in io.extractors.camera.attribute_times +- Fix for habituationChoiceWorld stim off times occuring outside of trial intervals +- Improvements to Timeline trials extractor, especially for valve open times +- trainingPhaseChoiceWorld added to Bpod protocol extractor map fixture +- Last trial of FPGA sessions now correctly extracted + +### other +- Removed deprecated pyschofit module +- Deprecated oneibl.globus module in favour of one.remote.globus +- Deprecated qc.task_extractors in favour of behaviour pipeline tasks ## Release Notes 2.26 From 9fbc0f273af68854538eed7a7253ff3a210cc0f8 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 13 Dec 2023 12:52:30 +0000 Subject: [PATCH 65/68] randomise session for histology --- ibllib/tests/qc/test_alignment_qc.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/ibllib/tests/qc/test_alignment_qc.py b/ibllib/tests/qc/test_alignment_qc.py index bb3e6a433..54bf8bd06 100644 --- a/ibllib/tests/qc/test_alignment_qc.py +++ b/ibllib/tests/qc/test_alignment_qc.py @@ -19,7 +19,6 @@ from ibllib.pipes.histology import register_track, register_chronic_track from one.registration import RegistrationClient - EPHYS_SESSION = 'b1c968ad-4874-468d-b2e4-5ffa9b9964e9' one = ONE(**TEST_DB) brain_atlas = AllenAtlas(25) @@ -37,7 +36,12 @@ class TestTracingQc(unittest.TestCase): def setUpClass(cls) -> None: probe = [''.join(random.choices(string.ascii_letters, k=5)), ''.join(random.choices(string.ascii_letters, k=5))] - ins = create_alyx_probe_insertions(session_path=EPHYS_SESSION, model='3B2', labels=probe, + date = str(datetime.date(2019, np.random.randint(1, 12), np.random.randint(1, 28))) + _, eid = RegistrationClient(one).create_new_session('ZM_1150', date=date) + cls.eid = str(eid) + # Currently the task protocol of a session must contain 'ephys' in order to create an insertion! + one.alyx.rest('sessions', 'partial_update', id=cls.eid, data={'task_protocol': 'ephys'}) + ins = create_alyx_probe_insertions(session_path=cls.eid, model='3B2', labels=probe, one=one, force=True) cls.probe00_id, cls.probe01_id = (x['id'] for x in ins) data = np.load(Path(Path(__file__).parent.parent. @@ -64,6 +68,7 @@ def test_tracing_not_exists(self): def tearDownClass(cls) -> None: one.alyx.rest('insertions', 'delete', id=cls.probe01_id) one.alyx.rest('insertions', 'delete', id=cls.probe00_id) + one.alyx.rest('sessions', 'delete', id=cls.eid) class TestChronicTracingQC(unittest.TestCase): @@ -73,12 +78,17 @@ def setUpClass(cls) -> None: serial = ''.join(random.choices(string.ascii_letters, k=10)) # Make a chronic insertions - ref = one.eid2ref(EPHYS_SESSION) - insdict = {"subject": ref['subject'], "name": probe, "model": '3B2', "serial": serial} + date = str(datetime.date(2019, np.random.randint(1, 12), np.random.randint(1, 28))) + _, eid = RegistrationClient(one).create_new_session('ZM_1150', date=date) + cls.eid = str(eid) + # Currently the task protocol of a session must contain 'ephys' in order to create an insertion! + one.alyx.rest('sessions', 'partial_update', id=cls.eid, data={'task_protocol': 'ephys'}) + + insdict = {"subject": 'ZM_1150', "name": probe, "model": '3B2', "serial": serial} ins = one.alyx.rest('chronic-insertions', 'create', data=insdict) cls.chronic_id = ins['id'] # Make a probe insertions - insdict = {"session": EPHYS_SESSION, "name": probe, "model": '3B2', "serial": serial, + insdict = {"session": cls.eid, "name": probe, "model": '3B2', "serial": serial, "chronic_insertion": cls.chronic_id} ins = one.alyx.rest('insertions', 'create', data=insdict) cls.probe_id = ins['id'] @@ -117,6 +127,7 @@ def test_tracing_not_exists(self): def tearDownClass(cls) -> None: one.alyx.rest('insertions', 'delete', id=cls.probe_id) one.alyx.rest('chronic-insertions', 'delete', id=cls.chronic_id) + one.alyx.rest('sessions', 'delete', id=cls.eid) class TestAlignmentQcExisting(unittest.TestCase): From 8454d8d455760d2de834dc0e93002fd1ce543e2e Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Wed, 13 Dec 2023 13:17:40 +0000 Subject: [PATCH 66/68] randomise all seeds --- ibllib/tests/qc/test_alignment_qc.py | 12 ++++++++---- ibllib/tests/qc/test_critical_reasons.py | 3 ++- ibllib/tests/test_oneibl.py | 3 ++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ibllib/tests/qc/test_alignment_qc.py b/ibllib/tests/qc/test_alignment_qc.py index 54bf8bd06..e4b7a2ba0 100644 --- a/ibllib/tests/qc/test_alignment_qc.py +++ b/ibllib/tests/qc/test_alignment_qc.py @@ -34,9 +34,10 @@ class TestTracingQc(unittest.TestCase): @classmethod def setUpClass(cls) -> None: + rng = np.random.default_rng() probe = [''.join(random.choices(string.ascii_letters, k=5)), ''.join(random.choices(string.ascii_letters, k=5))] - date = str(datetime.date(2019, np.random.randint(1, 12), np.random.randint(1, 28))) + date = str(datetime.date(2019, rng.integers(1, 12), rng.integers(1, 28))) _, eid = RegistrationClient(one).create_new_session('ZM_1150', date=date) cls.eid = str(eid) # Currently the task protocol of a session must contain 'ephys' in order to create an insertion! @@ -74,11 +75,12 @@ def tearDownClass(cls) -> None: class TestChronicTracingQC(unittest.TestCase): @classmethod def setUpClass(cls) -> None: + rng = np.random.default_rng() probe = ''.join(random.choices(string.ascii_letters, k=5)) serial = ''.join(random.choices(string.ascii_letters, k=10)) # Make a chronic insertions - date = str(datetime.date(2019, np.random.randint(1, 12), np.random.randint(1, 28))) + date = str(datetime.date(2019, rng.integers(1, 12), rng.integers(1, 28))) _, eid = RegistrationClient(one).create_new_session('ZM_1150', date=date) cls.eid = str(eid) # Currently the task protocol of a session must contain 'ephys' in order to create an insertion! @@ -136,6 +138,7 @@ class TestAlignmentQcExisting(unittest.TestCase): @classmethod def setUpClass(cls) -> None: + rng = np.random.default_rng() data = np.load(Path(Path(__file__).parent.parent. joinpath('fixtures', 'qc', 'data_alignmentqc_existing.npz')), allow_pickle=True) @@ -148,7 +151,7 @@ def setUpClass(cls) -> None: insertion = data['insertion'].tolist() insertion['name'] = ''.join(random.choices(string.ascii_letters, k=5)) insertion['json'] = {'xyz_picks': cls.xyz_picks} - date = str(datetime.date(2019, np.random.randint(1, 12), np.random.randint(1, 28))) + date = str(datetime.date(2019, rng.integers(1, 12), rng.integers(1, 28))) _, eid = RegistrationClient(one).create_new_session('ZM_1150', date=date) cls.eid = str(eid) # Currently the task protocol of a session must contain 'ephys' in order to create an insertion! @@ -254,6 +257,7 @@ class TestAlignmentQcManual(unittest.TestCase): @classmethod def setUpClass(cls) -> None: + rng = np.random.default_rng() data = np.load(Path(Path(__file__).parent.parent. joinpath('fixtures', 'qc', 'data_alignmentqc_manual.npz')), allow_pickle=True) @@ -268,7 +272,7 @@ def setUpClass(cls) -> None: insertion['name'] = ''.join(random.choices(string.ascii_letters, k=5)) insertion['json'] = {'xyz_picks': cls.xyz_picks} - date = str(datetime.date(2018, np.random.randint(1, 12), np.random.randint(1, 28))) + date = str(datetime.date(2018, rng.integers(1, 12), rng.integers(1, 28))) _, eid = RegistrationClient(one).create_new_session('ZM_1150', date=date) cls.eid = str(eid) insertion['session'] = cls.eid diff --git a/ibllib/tests/qc/test_critical_reasons.py b/ibllib/tests/qc/test_critical_reasons.py index 4094ff468..934cacab1 100644 --- a/ibllib/tests/qc/test_critical_reasons.py +++ b/ibllib/tests/qc/test_critical_reasons.py @@ -28,10 +28,11 @@ def mock_input(prompt): class TestUserPmtSess(unittest.TestCase): def setUp(self) -> None: + rng = np.random.default_rng() # Make sure tests use correct session ID one.alyx.clear_rest_cache() # Create new session on database with a random date to avoid race conditions - date = str(datetime.date(2022, np.random.randint(1, 12), np.random.randint(1, 28))) + date = str(datetime.date(2022, rng.integers(1, 12), rng.integers(1, 28))) from one.registration import RegistrationClient _, eid = RegistrationClient(one).create_new_session('ZM_1150', date=date) eid = str(eid) diff --git a/ibllib/tests/test_oneibl.py b/ibllib/tests/test_oneibl.py index 9b7fb1049..bd24fc8bf 100644 --- a/ibllib/tests/test_oneibl.py +++ b/ibllib/tests/test_oneibl.py @@ -267,6 +267,7 @@ def test_task_names_extractors(self): class TestRegistration(unittest.TestCase): def setUp(self) -> None: + rng = np.random.default_rng() self.one = ONE(**TEST_DB, cache_rest=None) # makes sure tests start without session created eid = self.one.search(subject=SUBJECT, date_range='2018-04-01', query_type='remote') @@ -292,7 +293,7 @@ def setUp(self) -> None: except HTTPError: self.rev = self.one.alyx.rest('revisions', 'create', data={'name': self.revision}) # Create a new tag - tag_data = {'name': f'test_tag_{np.random.randint(0, 1e3)}', 'protected': True} + tag_data = {'name': f'test_tag_{rng.integers(0, 1e3)}', 'protected': True} self.tag = self.one.alyx.rest('tags', 'create', data=tag_data) def test_registration_datasets(self): From 7c2a821fbcdcdc446198b5dd21ea266b6a47f2f9 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 13 Dec 2023 15:51:03 +0200 Subject: [PATCH 67/68] Update release notes; remove unused time date range function --- ibllib/tests/qc/test_alignment_qc.py | 39 ++++++---- ibllib/tests/qc/test_base_qc.py | 22 +++++- ibllib/tests/test_oneibl.py | 106 +++++++++++++++------------ ibllib/tests/test_time.py | 13 ---- ibllib/time.py | 11 --- release_notes.md | 1 + 6 files changed, 103 insertions(+), 89 deletions(-) diff --git a/ibllib/tests/qc/test_alignment_qc.py b/ibllib/tests/qc/test_alignment_qc.py index e4b7a2ba0..331850b17 100644 --- a/ibllib/tests/qc/test_alignment_qc.py +++ b/ibllib/tests/qc/test_alignment_qc.py @@ -55,15 +55,15 @@ def test_tracing_exists(self): channels=False, brain_atlas=brain_atlas) insertion = one.alyx.get('/insertions/' + self.probe00_id, clobber=True) - assert (insertion['json']['qc'] == 'NOT_SET') - assert (insertion['json']['extended_qc']['tracing_exists'] == 1) + self.assertEqual(insertion['json']['qc'], 'NOT_SET') + self.assertEqual(insertion['json']['extended_qc']['tracing_exists'], 1) def test_tracing_not_exists(self): register_track(self.probe01_id, picks=None, one=one, overwrite=True, channels=False, brain_atlas=brain_atlas) insertion = one.alyx.get('/insertions/' + self.probe01_id, clobber=True) - assert (insertion['json']['qc'] == 'CRITICAL') - assert (insertion['json']['extended_qc']['tracing_exists'] == 0) + self.assertEqual(insertion['json']['qc'], 'CRITICAL') + self.assertEqual(insertion['json']['extended_qc']['tracing_exists'], 0) @classmethod def tearDownClass(cls) -> None: @@ -106,24 +106,24 @@ def test_tracing_exists(self): channels=False, brain_atlas=brain_atlas) insertion = one.alyx.get('/insertions/' + self.probe_id, clobber=True) - assert (insertion['json']['qc'] == 'NOT_SET') - assert (insertion['json']['extended_qc']['tracing_exists'] == 1) + self.assertEqual(insertion['json']['qc'], 'NOT_SET') + self.assertEqual(insertion['json']['extended_qc']['tracing_exists'], 1) insertion = one.alyx.get('/chronic-insertions/' + self.chronic_id, clobber=True) - assert (insertion['json']['qc'] == 'NOT_SET') - assert (insertion['json']['extended_qc']['tracing_exists'] == 1) + self.assertEqual(insertion['json']['qc'], 'NOT_SET') + self.assertEqual(insertion['json']['extended_qc']['tracing_exists'], 1) def test_tracing_not_exists(self): register_chronic_track(self.chronic_id, picks=None, one=one, overwrite=True, channels=False, brain_atlas=brain_atlas) insertion = one.alyx.get('/insertions/' + self.probe_id, clobber=True) - assert (insertion['json']['qc'] == 'CRITICAL') - assert (insertion['json']['extended_qc']['tracing_exists'] == 0) + self.assertEqual(insertion['json']['qc'], 'CRITICAL') + self.assertEqual(insertion['json']['extended_qc']['tracing_exists'], 0) insertion = one.alyx.get('/chronic-insertions/' + self.chronic_id, clobber=True) - assert (insertion['json']['qc'] == 'CRITICAL') - assert (insertion['json']['extended_qc']['tracing_exists'] == 0) + self.assertEqual(insertion['json']['qc'], 'CRITICAL') + self.assertEqual(insertion['json']['extended_qc']['tracing_exists'], 0) @classmethod def tearDownClass(cls) -> None: @@ -135,6 +135,10 @@ def tearDownClass(cls) -> None: class TestAlignmentQcExisting(unittest.TestCase): probe_id = None prev_traj_id = None + eid = None + alignments = None + xyz_picks = None + trajectory = None @classmethod def setUpClass(cls) -> None: @@ -254,6 +258,10 @@ def tearDownClass(cls) -> None: class TestAlignmentQcManual(unittest.TestCase): probe_id = None prev_traj_id = None + eid = None + alignments = None + xyz_picks = None + trajectory = None @classmethod def setUpClass(cls) -> None: @@ -390,6 +398,9 @@ def _verify(tc, alignment_resolved=None, alignment_count=None, class TestUploadToFlatIron(unittest.TestCase): probe_id = None + alignments = None + xyz_picks = None + trajectory = None @unittest.skip("Skip FTP upload test") @classmethod @@ -425,7 +436,7 @@ def setUpClass(cls) -> None: print(cls.file_paths) def test_data_content(self): - alf_path = one.path_from_eid(EPHYS_SESSION).joinpath('alf', self.probe_name) + alf_path = one.eid2path(EPHYS_SESSION).joinpath('alf', self.probe_name) channels_mlapdv = np.load(alf_path.joinpath('channels.mlapdv.npy')) self.assertTrue(np.all(np.abs(channels_mlapdv) > 0)) channels_id = np.load(alf_path.joinpath('channels.brainLocationIds_ccf_2017.npy')) @@ -444,5 +455,5 @@ def tearDownClass(cls) -> None: one.alyx.rest('insertions', 'delete', id=cls.probe_id) -if __name__ == "__main__": +if __name__ == '__main__': unittest.main(exit=False, verbosity=2) diff --git a/ibllib/tests/qc/test_base_qc.py b/ibllib/tests/qc/test_base_qc.py index e56750c64..b1eb5b6a4 100644 --- a/ibllib/tests/qc/test_base_qc.py +++ b/ibllib/tests/qc/test_base_qc.py @@ -1,18 +1,30 @@ import unittest from unittest import mock +import random import numpy as np +from ibllib.tests import TEST_DB from ibllib.qc.base import QC from one.api import ONE -from ibllib.tests import TEST_DB +from one.registration import RegistrationClient one = ONE(**TEST_DB) class TestQC(unittest.TestCase): + """Test base QC class.""" + + eid = None + """str: An experiment UUID to use for updating QC fields.""" + + @classmethod + def setUpClass(cls): + date = f'20{random.randint(0, 30):02}-{random.randint(1, 12):02}-{random.randint(1, 28):02}' + _, eid = RegistrationClient(one).create_new_session('ZM_1150', date=date) + cls.eid = str(eid) + def setUp(self) -> None: - self.eid = 'b1c968ad-4874-468d-b2e4-5ffa9b9964e9' ses = one.alyx.rest('sessions', 'partial_update', id=self.eid, data={'qc': 'NOT_SET'}) assert ses['qc'] == 'NOT_SET', 'failed to reset qc field for test' extended = one.alyx.json_field_write('sessions', field_name='extended_qc', @@ -94,7 +106,7 @@ def test_update(self) -> None: self.qc.update('%INVALID%') def test_extended_qc(self) -> None: - """Test that the extended_qc JSON field is correctly updated""" + """Test that the extended_qc JSON field is correctly updated.""" current = one.alyx.rest('sessions', 'read', id=self.eid)['extended_qc'] data = {'_qc_test_foo': np.random.rand(), '_qc_test_bar': np.random.rand()} updated = self.qc.update_extended_qc(data) @@ -135,6 +147,10 @@ def test_compute_outcome_from_extended_qc(self): self.qc.json = True self.assertEqual(self.qc.compute_outcome_from_extended_qc(), 'WARNING') + @classmethod + def tearDownClass(cls): + one.alyx.rest('sessions', 'delete', id=cls.eid) + if __name__ == '__main__': unittest.main(exit=False, verbosity=2) diff --git a/ibllib/tests/test_oneibl.py b/ibllib/tests/test_oneibl.py index bd24fc8bf..a92305268 100644 --- a/ibllib/tests/test_oneibl.py +++ b/ibllib/tests/test_oneibl.py @@ -18,7 +18,6 @@ from ibllib.oneibl import patcher, registration import ibllib.io.extractors.base -from ibllib import __version__ from ibllib.tests import TEST_DB from ibllib.io import session_params @@ -206,35 +205,19 @@ def test_dsets_2_path(self): self.assertIsInstance(testable, PurePosixPath) -SUBJECT = 'clns0730' -USER = 'test_user' - -md5_0 = 'add2ab27dbf8428f8140-0870d5080c7f' -r = {'created_by': 'olivier', - 'path': f'{SUBJECT}/2018-08-24/002', - 'filenames': ["raw_behavior_data/_iblrig_encoderTrialInfo.raw.ssv"], - 'hashes': [md5_0], - 'filesizes': [1234], - 'versions': [__version__]} - -MOCK_SESSION_SETTINGS = { - 'SESSION_DATE': '2018-04-01', - 'SESSION_DATETIME': '2018-04-01T12:48:26.795526', - 'PYBPOD_CREATOR': [USER, 'f092c2d5-c98a-45a1-be7c-df05f129a93c', 'local'], - 'SESSION_NUMBER': '002', - 'SUBJECT_NAME': SUBJECT, - 'PYBPOD_BOARD': '_iblrig_mainenlab_behavior_1', - 'PYBPOD_PROTOCOL': '_iblrig_tasks_ephysChoiceWorld', - 'IBLRIG_VERSION': '5.4.1', - 'SUBJECT_WEIGHT': 22, -} - -MOCK_SESSION_DICT = { - 'subject': SUBJECT, - 'start_time': '2018-04-01T12:48:26.795526', - 'number': 2, - 'users': [USER] -} +def get_mock_session_settings(subject='clns0730', user='test_user'): + """Create a basic session settings file for testing.""" + return { + 'SESSION_DATE': '2018-04-01', + 'SESSION_DATETIME': '2018-04-01T12:48:26.795526', + 'PYBPOD_CREATOR': [user, 'f092c2d5-c98a-45a1-be7c-df05f129a93c', 'local'], + 'SESSION_NUMBER': '002', + 'SUBJECT_NAME': subject, + 'PYBPOD_BOARD': '_iblrig_mainenlab_behavior_1', + 'PYBPOD_PROTOCOL': '_iblrig_tasks_ephysChoiceWorld', + 'IBLRIG_VERSION': '5.4.1', + 'SUBJECT_WEIGHT': 22, + } class TestRegistrationEndpoint(unittest.TestCase): @@ -266,15 +249,34 @@ def test_task_names_extractors(self): class TestRegistration(unittest.TestCase): + subject = '' + """str: The name of the subject under which to create sessions.""" + + one = None + """one.api.OneAlyx: An instance of ONE connected to a test database.""" + + @classmethod + def setUpClass(cls): + """Create a random new subject.""" + cls.one = ONE(**TEST_DB, cache_rest=None) + cls.subject = ''.join(random.choices(string.ascii_letters, k=10)) + cls.one.alyx.rest('subjects', 'create', data={'lab': 'mainenlab', 'nickname': cls.subject}) + def setUp(self) -> None: - rng = np.random.default_rng() - self.one = ONE(**TEST_DB, cache_rest=None) + self.settings = get_mock_session_settings(self.subject) + self.session_dict = { + 'subject': self.subject, + 'start_time': '2018-04-01T12:48:26.795526', + 'number': 2, + 'users': [self.settings['PYBPOD_CREATOR'][0]] + } + # makes sure tests start without session created - eid = self.one.search(subject=SUBJECT, date_range='2018-04-01', query_type='remote') + eid = self.one.search(subject=self.subject, date_range='2018-04-01', query_type='remote') for ei in eid: self.one.alyx.rest('sessions', 'delete', id=ei) self.td = tempfile.TemporaryDirectory() - self.session_path = Path(self.td.name).joinpath(SUBJECT, '2018-04-01', '002') + self.session_path = Path(self.td.name).joinpath(self.subject, '2018-04-01', '002') self.alf_path = self.session_path.joinpath('alf') self.alf_path.mkdir(parents=True) np.save(self.alf_path.joinpath('spikes.times.npy'), np.random.random(500)) @@ -293,12 +295,13 @@ def setUp(self) -> None: except HTTPError: self.rev = self.one.alyx.rest('revisions', 'create', data={'name': self.revision}) # Create a new tag - tag_data = {'name': f'test_tag_{rng.integers(0, 1e3)}', 'protected': True} + tag_name = 'test_tag_' + ''.join(random.choices(string.ascii_letters, k=5)) + tag_data = {'name': tag_name, 'protected': True} self.tag = self.one.alyx.rest('tags', 'create', data=tag_data) def test_registration_datasets(self): # registers a single file - ses = self.one.alyx.rest('sessions', 'create', data=MOCK_SESSION_DICT) + ses = self.one.alyx.rest('sessions', 'create', data=self.session_dict) st_file = self.alf_path.joinpath('spikes.times.npy') registration.register_dataset(file_list=st_file, one=self.one) dsets = self.one.alyx.rest('datasets', 'list', session=ses['url'][-36:]) @@ -383,7 +386,7 @@ def _write_settings_file(self): behavior_path.mkdir() settings_file = behavior_path.joinpath('_iblrig_taskSettings.raw.json') with open(settings_file, 'w') as fid: - json.dump(MOCK_SESSION_SETTINGS, fid) + json.dump(self.settings, fid) return settings_file def test_create_sessions(self): @@ -409,7 +412,7 @@ def test_registration_session(self): settings_file = self._write_settings_file() rc = registration.IBLRegistrationClient(one=self.one) rc.register_session(str(self.session_path)) - eid = self.one.search(subject=SUBJECT, date_range=['2018-04-01', '2018-04-01'], + eid = self.one.search(subject=self.subject, date_range=['2018-04-01', '2018-04-01'], query_type='remote')[0] datasets = self.one.alyx.rest('datasets', 'list', session=eid) for ds in datasets: @@ -421,21 +424,21 @@ def test_registration_session(self): self.assertTrue(ses_info['procedures'] == ['Ephys recording with acute probe(s)']) self.one.alyx.rest('sessions', 'delete', id=eid) # re-register the session as behaviour this time - MOCK_SESSION_SETTINGS['PYBPOD_PROTOCOL'] = '_iblrig_tasks_trainingChoiceWorld6.3.1' + self.settings['PYBPOD_PROTOCOL'] = '_iblrig_tasks_trainingChoiceWorld6.3.1' with open(settings_file, 'w') as fid: - json.dump(MOCK_SESSION_SETTINGS, fid) + json.dump(self.settings, fid) rc.register_session(self.session_path) - eid = self.one.search(subject=SUBJECT, date_range=['2018-04-01', '2018-04-01'], + eid = self.one.search(subject=self.subject, date_range=['2018-04-01', '2018-04-01'], query_type='remote')[0] ses_info = self.one.alyx.rest('sessions', 'read', id=eid) self.assertTrue(ses_info['procedures'] == ['Behavior training/tasks']) self.one.alyx.rest('sessions', 'delete', id=eid) # re-register the session as unknown protocol this time - MOCK_SESSION_SETTINGS['PYBPOD_PROTOCOL'] = 'gnagnagna' + self.settings['PYBPOD_PROTOCOL'] = 'gnagnagna' with open(settings_file, 'w') as fid: - json.dump(MOCK_SESSION_SETTINGS, fid) + json.dump(self.settings, fid) rc.register_session(self.session_path) - eid = self.one.search(subject=SUBJECT, date_range=['2018-04-01', '2018-04-01'], + eid = self.one.search(subject=self.subject, date_range=['2018-04-01', '2018-04-01'], query_type='remote')[0] ses_info = self.one.alyx.rest('sessions', 'read', id=eid) self.assertTrue(ses_info['procedures'] == []) @@ -459,9 +462,9 @@ def test_register_chained_session(self): session_params.write_params(self.session_path, experiment_description) with open(behaviour_paths[1].joinpath('_iblrig_taskSettings.raw.json'), 'w') as fid: - json.dump(MOCK_SESSION_SETTINGS, fid) + json.dump(self.settings, fid) - settings = MOCK_SESSION_SETTINGS.copy() + settings = self.settings.copy() settings['PYBPOD_PROTOCOL'] = '_iblrig_tasks_passiveChoiceWorld' start_time = (datetime.datetime.fromisoformat(settings['SESSION_DATETIME']) - datetime.timedelta(hours=1, minutes=2, seconds=12)) @@ -481,7 +484,7 @@ def test_register_chained_session(self): expected = '_iblrig_tasks_passiveChoiceWorld5.4.1/_iblrig_tasks_ephysChoiceWorld5.4.1' self.assertEqual(expected, ses_info['task_protocol']) # Test weightings created on Alyx - w = self.one.alyx.rest('subjects', 'read', id=SUBJECT)['weighings'] + w = self.one.alyx.rest('subjects', 'read', id=self.subject)['weighings'] self.assertEqual(2, len(w)) self.assertCountEqual({22.}, {x['weight'] for x in w}) weight_dates = {x['date_time'] for x in w} @@ -501,9 +504,16 @@ def tearDown(self) -> None: for rev in v1_rev: self.one.alyx.rest('revisions', 'delete', id=rev['name']) # Delete weighings - for w in self.one.alyx.rest('subjects', 'read', id=SUBJECT)['weighings']: + for w in self.one.alyx.rest('subjects', 'read', id=self.subject)['weighings']: self.one.alyx.rest('weighings', 'delete', id=w['url'].split('/')[-1]) + @classmethod + def tearDownClass(cls) -> None: + # Note: datasets deleted in cascade + for ses in cls.one.alyx.rest('sessions', 'list', subject=cls.subject, no_cache=True): + cls.one.alyx.rest('sessions', 'delete', id=ses['url'][-36:]) + cls.one.alyx.rest('subjects', 'delete', id=cls.subject) + if __name__ == '__main__': unittest.main() diff --git a/ibllib/tests/test_time.py b/ibllib/tests/test_time.py index 7df999145..43cd694d8 100644 --- a/ibllib/tests/test_time.py +++ b/ibllib/tests/test_time.py @@ -6,19 +6,6 @@ class TestUtils(unittest.TestCase): - def test_format_date_range(self): - date_range = ['2018-03-01', '2018-03-24'] - date_range_out = ['2018-03-01', '2018-03-24'] - # test the string input - self.assertTrue(ibllib.time.format_date_range(date_range) == date_range_out) - # test the date input - date_range = [datetime.datetime.strptime(d, '%Y-%m-%d') for d in date_range] - self.assertTrue(ibllib.time.format_date_range(date_range) == date_range_out) - # test input validation - date_range[-1] = date_range_out[-1] # noqa [datetime, str] - with self.assertRaises(ValueError): - ibllib.time.format_date_range(date_range) - def test_isostr2date(self): # test the full string a = ibllib.time.isostr2date('2018-03-01T12:34:56.99999') diff --git a/ibllib/time.py b/ibllib/time.py index 5776666c6..95aa0cf61 100644 --- a/ibllib/time.py +++ b/ibllib/time.py @@ -32,17 +32,6 @@ def date2isostr(adate): return datetime.datetime.isoformat(adate) -def format_date_range(date_range): - if all([isinstance(d, str) for d in date_range]): - date_range = [datetime.datetime.strptime(d, '%Y-%m-%d') for d in date_range] - elif not all([isinstance(d, datetime.date) for d in date_range]): - raise ValueError('Date range doesn''t have proper format: list of 2 strings "yyyy-mm-dd" ') - # the django filter is implemented in datetime and assumes the beginning of the day (last day - # is excluded by default - date_range = [d.strftime('%Y-%m-%d') for d in date_range] - return date_range - - def convert_pgts(time): """Convert PointGray cameras timestamps to seconds. Use convert then uncycle""" diff --git a/release_notes.md b/release_notes.md index 29984fb50..569aef684 100644 --- a/release_notes.md +++ b/release_notes.md @@ -15,6 +15,7 @@ - Improvements to Timeline trials extractor, especially for valve open times - trainingPhaseChoiceWorld added to Bpod protocol extractor map fixture - Last trial of FPGA sessions now correctly extracted +- Correct dynamic pipeline extraction of passive choice world trials ### other - Removed deprecated pyschofit module From 934f9da77fc41685193262394d8f06e0e29d61e8 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 13 Dec 2023 16:21:48 +0200 Subject: [PATCH 68/68] Remove double import; delete session notes first on teardown --- ibllib/tests/qc/test_critical_reasons.py | 13 ++++++------- ibllib/tests/test_ephys.py | 1 - 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/ibllib/tests/qc/test_critical_reasons.py b/ibllib/tests/qc/test_critical_reasons.py index 934cacab1..038eaf52f 100644 --- a/ibllib/tests/qc/test_critical_reasons.py +++ b/ibllib/tests/qc/test_critical_reasons.py @@ -4,14 +4,14 @@ import random import string import datetime -import numpy as np +import numpy as np import requests from one.api import ONE +from one.registration import RegistrationClient from ibllib.tests import TEST_DB import ibllib.qc.critical_reasons as usrpmt -from one.registration import RegistrationClient one = ONE(**TEST_DB) @@ -33,7 +33,6 @@ def setUp(self) -> None: one.alyx.clear_rest_cache() # Create new session on database with a random date to avoid race conditions date = str(datetime.date(2022, rng.integers(1, 12), rng.integers(1, 28))) - from one.registration import RegistrationClient _, eid = RegistrationClient(one).create_new_session('ZM_1150', date=date) eid = str(eid) # Currently the task protocol of a session must contain 'ephys' in order to create an insertion! @@ -144,6 +143,10 @@ def tearDown(self) -> None: except requests.HTTPError as ex: if ex.errno != 404: raise ex + + notes = one.alyx.rest('notes', 'list', django=f'object_id,{self.sess_id}', no_cache=True) + for n in notes: + one.alyx.rest('notes', 'delete', id=n['id']) text = '"title": "=== EXPERIMENTER REASON(S)' notes = one.alyx.rest('notes', 'list', django=f'text__icontains,{text}', no_cache=True) for n in notes: @@ -153,10 +156,6 @@ def tearDown(self) -> None: for n in notes: one.alyx.rest('notes', 'delete', n['id']) - note = one.alyx.rest('notes', 'list', django=f'object_id,{self.sess_id}', no_cache=True) - for no in note: - one.alyx.rest('notes', 'delete', id=no['id']) - class TestSignOffNote(unittest.TestCase): def setUp(self) -> None: diff --git a/ibllib/tests/test_ephys.py b/ibllib/tests/test_ephys.py index dfd368c25..69267a347 100644 --- a/ibllib/tests/test_ephys.py +++ b/ibllib/tests/test_ephys.py @@ -140,7 +140,6 @@ def tearDownClass(cls) -> None: cls.tempdir.cleanup() def setUp(self) -> None: - self.eid = 'b1c968ad-4874-468d-b2e4-5ffa9b9964e9' # make a temp probe insertion self.pname = 'probe02'