diff --git a/brainbox/io/one.py b/brainbox/io/one.py index c9b25a778..ee55ea9d0 100644 --- a/brainbox/io/one.py +++ b/brainbox/io/one.py @@ -866,13 +866,21 @@ def _get_attributes(dataset_types): waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes)) return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes} - def _get_spike_sorting_collection(self, spike_sorter='pykilosort'): + def _get_spike_sorting_collection(self, spike_sorter=None): """ Filters a list or array of collections to get the relevant spike sorting dataset if there is a pykilosort, load it """ - collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None) - # otherwise, prefers the shortest + for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']): + if sorter is None: + continue + if sorter == "": + collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None) + else: + collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None) + if collection is not None: + return collection + # if none is found amongst the defaults, prefers the shortest collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None) _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}") return collection @@ -982,14 +990,13 @@ def download_raw_waveforms(self, **kwargs): """ _logger.debug(f"loading waveforms from {self.collection}") return self.one.load_object( - self.eid, "waveforms", - attribute=["traces", "templates", "table", "channels"], + id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"], collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs ) def raw_waveforms(self, **kwargs): wf_paths = self.download_raw_waveforms(**kwargs) - return WaveformsLoader(wf_paths[0].parent, wfs_dtype=np.float16) + return WaveformsLoader(wf_paths[0].parent) def load_channels(self, **kwargs): """ @@ -1022,7 +1029,7 @@ def load_channels(self, **kwargs): self.histology = 'alf' return Bunch(channels) - def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs): + def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, **kwargs): """ Loads spikes, clusters and channels diff --git a/brainbox/io/spikeglx.py b/brainbox/io/spikeglx.py index 9c0618c11..fff72c5f2 100644 --- a/brainbox/io/spikeglx.py +++ b/brainbox/io/spikeglx.py @@ -128,6 +128,7 @@ def __init__(self, pid, one, typ='ap', cache_folder=None, remove_cached=False): self.file_chunks = self.one.load_dataset(self.eid, f'*.{typ}.ch', collection=f"*{self.pname}") meta_file = self.one.load_dataset(self.eid, f'*.{typ}.meta', collection=f"*{self.pname}") cbin_rec = self.one.list_datasets(self.eid, collection=f"*{self.pname}", filename=f'*{typ}.*bin', details=True) + cbin_rec.index = cbin_rec.index.map(lambda x: (self.eid, x)) self.url_cbin = self.one.record2url(cbin_rec)[0] with open(self.file_chunks, 'r') as f: self.chunks = json.load(f) diff --git a/ibllib/ephys/sync_probes.py b/ibllib/ephys/sync_probes.py index 3f3411479..54106b245 100644 --- a/ibllib/ephys/sync_probes.py +++ b/ibllib/ephys/sync_probes.py @@ -47,7 +47,7 @@ def sync(ses_path, **kwargs): return version3B(ses_path, **kwargs) -def version3A(ses_path, display=True, type='smooth', tol=2.1): +def version3A(ses_path, 
display=True, type='smooth', tol=2.1, probe_names=None): """ From a session path with _spikeglx_sync arrays extracted, locate ephys files for 3A and outputs one sync.timestamps.probeN.npy file per acquired probe. By convention the reference diff --git a/ibllib/oneibl/data_handlers.py b/ibllib/oneibl/data_handlers.py index ba713babb..b0c40c735 100644 --- a/ibllib/oneibl/data_handlers.py +++ b/ibllib/oneibl/data_handlers.py @@ -21,7 +21,7 @@ from iblutil.util import flatten, ensure_list from ibllib.oneibl.registration import register_dataset, get_lab, get_local_data_repository -from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH +from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH, S3Patcher _logger = logging.getLogger(__name__) @@ -747,6 +747,38 @@ def cleanUp(self): os.unlink(file) +class RemoteEC2DataHandler(DataHandler): + def __init__(self, session_path, signature, one=None): + """ + Data handler for running tasks on remote compute node. Will download missing data via http using ONE + + :param session_path: path to session + :param signature: input and output file signatures + :param one: ONE instance + """ + super().__init__(session_path, signature, one=one) + + def setUp(self): + """ + Function to download necessary data to run tasks using ONE + :return: + """ + df = super().getData() + self.one._check_filesystem(df) + + def uploadData(self, outputs, version, **kwargs): + """ + Function to upload and register data of completed task via S3 patcher + :param outputs: output files from task to register + :param version: ibllib version + :return: output info of registered datasets + """ + versions = super().uploadData(outputs, version) + s3_patcher = S3Patcher(one=self.one) + return s3_patcher.patch_dataset(outputs, created_by=self.one.alyx.user, + versions=versions, **kwargs) + + class RemoteHttpDataHandler(DataHandler): def __init__(self, session_path, signature, one=None): """ diff --git a/ibllib/oneibl/patcher.py b/ibllib/oneibl/patcher.py index 3738d7bcf..22f682df4 100644 --- a/ibllib/oneibl/patcher.py +++ b/ibllib/oneibl/patcher.py @@ -34,13 +34,13 @@ import globus_sdk import iblutil.io.params as iopar from iblutil.util import ensure_list -from one.alf.files import get_session_path, add_uuid_string +from one.alf.files import get_session_path, add_uuid_string, full_path_parts from one.alf.spec import is_uuid_string, is_uuid from one import params from one.webclient import AlyxClient from one.converters import path_from_dataset from one.remote import globus -from one.remote.aws import url2uri +from one.remote.aws import url2uri, get_s3_from_alyx from ibllib.oneibl.registration import register_dataset @@ -633,3 +633,55 @@ def _scp(self, local_path, remote_path, dry=True): def _rm(self, flatiron_path, dry=True): raise PermissionError("This Patcher does not have admin permissions to remove data " "from the FlatIron server") + + +class S3Patcher(Patcher): + + def __init__(self, one=None): + assert one + super().__init__(one=one) + self.s3_repo = 's3_patcher' + self.s3_path = 'patcher' + + # Instantiate boto connection + self.s3, self.bucket = get_s3_from_alyx(self.one.alyx, repo_name=self.s3_repo) + + def check_datasets(self, file_list): + # Here we want to check if the datasets exist, if they do we don't want to patch unless we force. 
+ exists = [] + for file in file_list: + collection = full_path_parts(file, as_dict=True)['collection'] + dset = self.one.alyx.rest('datasets', 'list', session=self.one.path2eid(file), name=file.name, + collection=collection, clobber=True) + if len(dset) > 0: + exists.append(file) + + return exists + + def patch_dataset(self, file_list, dry=False, ftp=False, force=False, **kwargs): + + exists = self.check_datasets(file_list) + if len(exists) > 0 and not force: + _logger.error(f'Files: {", ".join([f.name for f in exists])} already exist; set force=True to overwrite') + return + + response = super().patch_dataset(file_list, dry=dry, repository=self.s3_repo, ftp=False, **kwargs) + # TODO in an ideal case the flatiron filerecord won't be altered when we register this dataset. This requires + # changing the alyx.data.register_view + for ds in response: + frs = ds['file_records'] + fr_server = next(filter(lambda fr: 'flatiron' in fr['data_repository'], frs)) + # Update the flatiron file record to exists=False + self.one.alyx.rest('files', 'partial_update', id=fr_server['id'], + data={'exists': False}) + + def _scp(self, local_path, remote_path, dry=True): + + aws_remote_path = Path(self.s3_path).joinpath(remote_path.relative_to(FLATIRON_MOUNT)) + _logger.info(f'Transferring file {local_path} to {aws_remote_path}') + self.s3.Bucket(self.bucket).upload_file(str(PurePosixPath(local_path)), str(PurePosixPath(aws_remote_path))) + + return 0, '' + + def _rm(self, *args, **kwargs): + raise PermissionError("This Patcher does not have admin permissions to remove data.") diff --git a/ibllib/pipes/dynamic_pipeline.py b/ibllib/pipes/dynamic_pipeline.py index 5c2fc224f..c2c41bd03 100644 --- a/ibllib/pipes/dynamic_pipeline.py +++ b/ibllib/pipes/dynamic_pipeline.py @@ -490,8 +490,6 @@ def make_pipeline(session_path, **pkwargs): tasks[f'RawEphysQC_{pname}'] = type(f'RawEphysQC_{pname}', (etasks.RawEphysQC,), {})( **kwargs, **ephys_kwargs, pname=pname, parents=register_task) - tasks[f'EphysCellQC_{pname}'] = type(f'EphysCellQC_{pname}', (etasks.EphysCellsQc,), {})( - **kwargs, **ephys_kwargs, pname=pname, parents=[tasks[f'Spikesorting_{pname}']]) # Video tasks if 'cameras' in devices: diff --git a/ibllib/pipes/ephys_tasks.py b/ibllib/pipes/ephys_tasks.py index 4d794b19f..8596a1619 100644 --- a/ibllib/pipes/ephys_tasks.py +++ b/ibllib/pipes/ephys_tasks.py @@ -355,9 +355,9 @@ class EphysSyncPulses(SyncPulses): @property def signature(self): signature = { - 'input_files': [('*nidq.cbin', self.sync_collection, True), + 'input_files': [('*nidq.cbin', self.sync_collection, False), ('*nidq.ch', self.sync_collection, False), - ('*nidq.meta', self.sync_collection, True), + ('*nidq.meta', self.sync_collection, False), ('*nidq.wiring.json', self.sync_collection, True)], 'output_files': [('_spikeglx_sync.times.npy', self.sync_collection, True), ('_spikeglx_sync.polarities.npy', self.sync_collection, True), @@ -393,13 +393,19 @@ def __init__(self, *args, **kwargs): @property def signature(self): signature = { - 'input_files': [('*ap.meta', f'{self.device_collection}/{pname}', True) for pname in self.pname] + - [('*ap.cbin', f'{self.device_collection}/{pname}', True) for pname in self.pname] + - [('*ap.ch', f'{self.device_collection}/{pname}', True) for pname in self.pname] + - [('*ap.wiring.json', f'{self.device_collection}/{pname}', False) for pname in self.pname] + - [('_spikeglx_sync.times.npy', self.sync_collection, True), - ('_spikeglx_sync.polarities.npy', self.sync_collection, True), - ('_spikeglx_sync.channels.npy',
self.sync_collection, True)], + 'input_files': + [('*ap.meta', f'{self.device_collection}/{pname}', True) for pname in self.pname] + + [('*ap.cbin', f'{self.device_collection}/{pname}', True) for pname in self.pname] + + [('*ap.ch', f'{self.device_collection}/{pname}', True) for pname in self.pname] + + [('*ap.wiring.json', f'{self.device_collection}/{pname}', False) for pname in self.pname] + + [('_spikeglx_sync.times.*npy', f'{self.device_collection}/{pname}', False) for pname in self.pname] + + [('_spikeglx_sync.polarities.*npy', f'{self.device_collection}/{pname}', False) for pname in self.pname] + + [('_spikeglx_sync.channels.*npy', f'{self.device_collection}/{pname}', False) for pname in self.pname] + + [('_spikeglx_sync.times.*npy', self.sync_collection, True), + ('_spikeglx_sync.polarities.*npy', self.sync_collection, True), + ('_spikeglx_sync.channels.*npy', self.sync_collection, True), + ('*ap.meta', self.sync_collection, True) + ], 'output_files': [(f'_spikeglx_sync.times.{pname}.npy', f'{self.device_collection}/{pname}', True) for pname in self.pname] + [(f'_spikeglx_sync.polarities.{pname}.npy', f'{self.device_collection}/{pname}', True) @@ -517,8 +523,22 @@ def compute_cell_qc(folder_alf_probe): df_units = pd.concat( [df_units, ks2_labels['ks2_label'].reindex(df_units.index)], axis=1) # save as parquet file - df_units.to_parquet(folder_alf_probe.joinpath("clusters.metrics.pqt")) - return folder_alf_probe.joinpath("clusters.metrics.pqt"), df_units, drift + df_units.to_parquet(file_metrics := folder_alf_probe.joinpath("clusters.metrics.pqt")) + + assert np.all((df_units['bitwise_fail'] == 0) == (df_units['label'] == 1)) # useless but sanity check for OW + + cok = df_units['bitwise_fail'] == 0 + sok = cok[spikes['clusters']].values + spikes['templates'] = spikes['templates'].astype(np.uint16) + spikes['clusters'] = spikes['clusters'].astype(np.uint16) + spikes['depths'] = spikes['depths'].astype(np.float32) + spikes['amps'] = spikes['amps'].astype(np.float32) + file_passing = folder_alf_probe.joinpath('passingSpikes.table.pqt') + df_spikes = pd.DataFrame(spikes) + df_spikes = df_spikes.iloc[sok, :].reset_index(drop=True) + df_spikes.to_parquet(file_passing) + + return [file_metrics, file_passing], df_units, drift def _label_probe_qc(self, folder_probe, df_units, drift): """ @@ -564,26 +584,87 @@ class SpikeSorting(base_tasks.EphysTask, CellQCMixin): priority = 60 job_size = 'large' force = True - + env = 'iblsorter' + _sortername = 'iblsorter' SHELL_SCRIPT = Path.home().joinpath( - "Documents/PYTHON/iblscripts/deploy/serverpc/iblsorter/run_iblsorter.sh" + f"Documents/PYTHON/iblscripts/deploy/serverpc/{_sortername}/sort_recording.sh" ) SPIKE_SORTER_NAME = 'iblsorter' - PYKILOSORT_REPO = Path.home().joinpath('Documents/PYTHON/SPIKE_SORTING/ibl-sorter') + SORTER_REPOSITORY = Path.home().joinpath('Documents/PYTHON/SPIKE_SORTING/ibl-sorter') @property def signature(self): signature = { - 'input_files': [('*ap.meta', f'{self.device_collection}/{self.pname}', True), - ('*ap.*bin', f'{self.device_collection}/{self.pname}', True), - ('*ap.ch', f'{self.device_collection}/{self.pname}', False), - ('*sync.npy', f'{self.device_collection}/{self.pname}', True)], - 'output_files': [('spike_sorting_pykilosort.log', f'spike_sorters/pykilosort/{self.pname}', True), - ('_iblqc_ephysTimeRmsAP.rms.npy', f'{self.device_collection}/{self.pname}', True), - ('_iblqc_ephysTimeRmsAP.timestamps.npy', f'{self.device_collection}/{self.pname}', True)] + 'input_files': [ + ('*ap.meta', 
f'{self.device_collection}/{self.pname}', True), + ('*ap.*bin', f'{self.device_collection}/{self.pname}', True), + ('*ap.ch', f'{self.device_collection}/{self.pname}', False), + ('*sync.npy', f'{self.device_collection}/{self.pname}', True) + ], + 'output_files': [ + # ./raw_ephys_data/probe00/ + ('_iblqc_ephysTimeRmsAP.rms.npy', f'{self.device_collection}/{self.pname}/', True), + ('_iblqc_ephysTimeRmsAP.timestamps.npy', f'{self.device_collection}/{self.pname}/', True), + ('_iblqc_ephysSaturation.samples.npy', f'{self.device_collection}/{self.pname}/', True), + # ./spike_sorters/iblsorter/probe00 + ('spike_sorting_iblsorter.log', f'spike_sorters/{self._sortername}/{self.pname}', True), + ('_kilosort_raw.output.tar', f'spike_sorters/{self._sortername}/{self.pname}/', True), + # ./alf/probe00/iblsorter + ('_kilosort_whitening.matrix.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('_phy_spikes_subset.channels.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('_phy_spikes_subset.spikes.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('_phy_spikes_subset.waveforms.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('channels.labels.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('channels.localCoordinates.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('channels.rawInd.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('clusters.amps.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('clusters.channels.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('clusters.depths.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('clusters.metrics.pqt', f'alf/{self.pname}/{self._sortername}/', True), + ('clusters.peakToTrough.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('clusters.uuids.csv', f'alf/{self.pname}/{self._sortername}/', True), + ('clusters.waveforms.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('clusters.waveformsChannels.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('drift.times.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('drift.um.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('drift_depths.um.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('passingSpikes.table.pqt', f'alf/{self.pname}/{self._sortername}/', True), + ('spikes.amps.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('spikes.clusters.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('spikes.depths.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('spikes.samples.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('spikes.templates.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('spikes.times.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('templates.amps.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('templates.waveforms.npy', f'alf/{self.pname}/{self._sortername}/', True), + ('templates.waveformsChannels.npy', f'alf/{self.pname}/{self._sortername}/', True), + ], } return signature + @property + def scratch_folder_run(self): + """ + Constructs a path to a temporary folder for the spike sorting output and scratch files + This is usually on a high performance drive, and we should factor around 2.5 times the uncompressed raw recording size + For a scratch drive at /mnt/h0 we would have the following temp dir: + /mnt/h0/iblsorter_1.8.0_CSHL071_2020-10-04_001_probe01/ + """ + # get the scratch drive from the shell script + if self.scratch_folder is None: + with open(self.SHELL_SCRIPT) as fid: + lines = fid.readlines() + line = [line for line in 
lines if line.startswith("SCRATCH_DRIVE=")][0] + m = re.search(r"\=(.*?)(\#|\n)", line)[0] + scratch_drive = Path(m[1:-1].strip()) + else: + scratch_drive = self.scratch_folder + assert scratch_drive.exists(), f"Scratch drive {scratch_drive} not found" + # get the version of the sorter + self.version = self._fetch_iblsorter_version(self.SORTER_REPOSITORY) + spikesorter_dir = f"{self.version}_{'_'.join(list(self.session_path.parts[-3:]))}_{self.pname}" + return scratch_drive.joinpath(spikesorter_dir) + @staticmethod def _sample2v(ap_file): md = spikeglx.read_meta_data(ap_file.with_suffix(".meta")) @@ -597,7 +678,7 @@ def _fetch_iblsorter_version(repo_path): return f"iblsorter_{iblsorter.__version__}" except ImportError: _logger.info('IBL-sorter not in environment, trying to locate the repository') - init_file = Path(repo_path).joinpath('ibl-sorter', '__init__.py') + init_file = Path(repo_path).joinpath('iblsorter', '__init__.py') try: with open(init_file) as fid: lines = fid.readlines() @@ -619,7 +700,7 @@ def _fetch_iblsorter_run_version(log_file): line = fid.readline() version = re.search('version (.*), output', line) version = version or re.search('version (.*)', line) # old versions have output, new have a version line - version = re.sub(r'\^\[{2}[0-9]+m', '', version.group(1)) # removes the coloring tags + version = re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', version.group(1)) return version def _run_iblsort(self, ap_file): @@ -629,9 +710,7 @@ def _run_iblsort(self, ap_file): (discontinued support for old spike sortings in the probe folder <1.5.5) :return: path of the folder containing ks2 spike sorting output """ - self.version = self._fetch_iblsorter_version(self.PYKILOSORT_REPO) - label = ap_file.parts[-2] # this is usually the probe name - sorter_dir = self.session_path.joinpath("spike_sorters", self.SPIKE_SORTER_NAME, label) + sorter_dir = self.session_path.joinpath("spike_sorters", self.SPIKE_SORTER_NAME, self.pname) self.FORCE_RERUN = False if not self.FORCE_RERUN: log_file = sorter_dir.joinpath(f"spike_sorting_{self.SPIKE_SORTER_NAME}.log") @@ -643,24 +722,15 @@ def _run_iblsort(self, ap_file): return sorter_dir else: self.FORCE_RERUN = True - # get the scratch drive from the shell script - with open(self.SHELL_SCRIPT) as fid: - lines = fid.readlines() - line = [line for line in lines if line.startswith("SCRATCH_DRIVE=")][0] - m = re.search(r"\=(.*?)(\#|\n)", line)[0] - scratch_drive = Path(m[1:-1].strip()) - assert scratch_drive.exists() - spikesorter_dir = f"{self.version}_{'_'.join(list(self.session_path.parts[-3:]))}_{self.pname}" - temp_dir = scratch_drive.joinpath(spikesorter_dir) - _logger.info(f"job progress command: tail -f {temp_dir} *.log") - temp_dir.mkdir(parents=True, exist_ok=True) + _logger.info(f"job progress command: tail -f {self.scratch_folder_run} *.log") + self.scratch_folder_run.mkdir(parents=True, exist_ok=True) check_nvidia_driver() try: # if pykilosort is in the environment, use the installed version within the task import iblsorter.ibl # noqa - iblsorter.ibl.run_spike_sorting_ibl(bin_file=ap_file, scratch_dir=temp_dir) + iblsorter.ibl.run_spike_sorting_ibl(bin_file=ap_file, scratch_dir=self.scratch_folder_run, delete=False) except ImportError: - command2run = f"{self.SHELL_SCRIPT} {ap_file} {temp_dir}" + command2run = f"{self.SHELL_SCRIPT} {ap_file} {self.scratch_folder_run}" _logger.info(command2run) process = subprocess.Popen( command2run, @@ -675,16 +745,13 @@ def _run_iblsort(self, ap_file): if process.returncode != 0: error_str = 
error.decode("utf-8").strip() # try and get the kilosort log if any - for log_file in temp_dir.rglob('*_kilosort.log'): + for log_file in self.scratch_folder_run.rglob('*_kilosort.log'): with open(log_file) as fid: log = fid.read() _logger.error(log) break raise RuntimeError(f"{self.SPIKE_SORTER_NAME} {info_str}, {error_str}") - - shutil.copytree(temp_dir.joinpath('output'), sorter_dir, dirs_exist_ok=True) - shutil.rmtree(temp_dir, ignore_errors=True) - + shutil.copytree(self.scratch_folder_run.joinpath('output'), sorter_dir, dirs_exist_ok=True) return sorter_dir def _run(self): @@ -698,34 +765,41 @@ def _run(self): """ efiles = spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname)) ap_files = [(ef.get("ap"), ef.get("label")) for ef in efiles if "ap" in ef.keys()] + assert len(ap_files) != 0, f"No ap file found for probe {self.session_path.joinpath(self.device_collection, self.pname)}" assert len(ap_files) == 1, f"Several bin files found for the same probe {ap_files}" ap_file, label = ap_files[0] out_files = [] - ks2_dir = self._run_iblsort(ap_file) # runs the sorter, skips if it already ran + sorter_dir = self._run_iblsort(ap_file) # runs the sorter, skips if it already ran + # convert the data to ALF in the ./alf/probeXX/SPIKE_SORTER_NAME folder probe_out_path = self.session_path.joinpath("alf", label, self.SPIKE_SORTER_NAME) shutil.rmtree(probe_out_path, ignore_errors=True) probe_out_path.mkdir(parents=True, exist_ok=True) ibllib.ephys.spikes.ks2_to_alf( - ks2_dir, + sorter_dir, bin_path=ap_file.parent, out_path=probe_out_path, bin_file=ap_file, ampfactor=self._sample2v(ap_file), ) - logfile = ks2_dir.joinpath(f"spike_sorting_{self.SPIKE_SORTER_NAME}.log") + logfile = sorter_dir.joinpath(f"spike_sorting_{self.SPIKE_SORTER_NAME}.log") if logfile.exists(): shutil.copyfile(logfile, probe_out_path.joinpath(f"_ibl_log.info_{self.SPIKE_SORTER_NAME}.log")) + # recover the QC files from the spike sorting output and copy them + for file_qc in sorter_dir.glob('_iblqc_*.npy'): + shutil.move(file_qc, file_qc_out := ap_file.parent.joinpath(file_qc.name)) + out_files.append(file_qc_out) # Sync spike sorting with the main behaviour clock: the nidq for 3B+ and the main probe for 3A out, _ = ibllib.ephys.spikes.sync_spike_sorting(ap_file=ap_file, out_path=probe_out_path) out_files.extend(out) # Now compute the unit metrics - self.compute_cell_qc(probe_out_path) + files_qc, df_units, drift = self.compute_cell_qc(probe_out_path) + out_files.extend(files_qc) # convert ks2_output into tar file and also register # Make this in case spike sorting is in old raw_ephys_data folders, for new # sessions it should already exist tar_dir = self.session_path.joinpath('spike_sorters', self.SPIKE_SORTER_NAME, label) tar_dir.mkdir(parents=True, exist_ok=True) - out = ibllib.ephys.spikes.ks2_to_tar(ks2_dir, tar_dir, force=self.FORCE_RERUN) + out = ibllib.ephys.spikes.ks2_to_tar(sorter_dir, tar_dir, force=self.FORCE_RERUN) out_files.extend(out) # run waveform extraction _logger.info("Running waveform extraction") @@ -733,28 +807,29 @@ def _run(self): clusters = alfio.load_object(probe_out_path, 'clusters', attribute=['channels']) channels = alfio.load_object(probe_out_path, 'channels') extract_wfs_cbin( - cbin_file=ap_file, + bin_file=ap_file, output_dir=probe_out_path, spike_samples=spikes['samples'], spike_clusters=spikes['clusters'], spike_channels=clusters['channels'][spikes['clusters']], - h=None, # todo the geometry needs to be set using the spikeglx object 
channel_labels=channels['labels'], max_wf=256, trough_offset=42, spike_length_samples=128, - chunksize_samples=int(3000), + chunksize_samples=int(30_000), n_jobs=None, wfs_dtype=np.float16, - preprocessing_steps=["phase_shift", - "bad_channel_interpolation", - "butterworth", - "car"] + preprocess_steps=["phase_shift", "bad_channel_interpolation", "butterworth", "car"], + scratch_dir=self.scratch_folder_run, ) + _logger.info(f"Cleaning up temporary folder {self.scratch_folder_run}") + shutil.rmtree(self.scratch_folder_run, ignore_errors=True) if self.one: eid = self.one.path2eid(self.session_path, query_type='remote') ins = self.one.alyx.rest('insertions', 'list', session=eid, name=label, query_type='remote') if len(ins) != 0: + _logger.info("Populating probe insertion with qc") + self._label_probe_qc(probe_out_path, df_units, drift) _logger.info("Creating SpikeSorting QC plots") plot_task = ApPlots(ins[0]['id'], session_path=self.session_path, one=self.one) _ = plot_task.run() @@ -772,39 +847,3 @@ def _run(self): out_files.extend(out) return out_files - - -class EphysCellsQc(base_tasks.EphysTask, CellQCMixin): - priority = 90 - job_size = 'small' - - @property - def signature(self): - signature = { - 'input_files': [('spikes.times.npy', f'alf/{self.pname}*', True), - ('spikes.clusters.npy', f'alf/{self.pname}*', True), - ('spikes.amps.npy', f'alf/{self.pname}*', True), - ('spikes.depths.npy', f'alf/{self.pname}*', True), - ('clusters.channels.npy', f'alf/{self.pname}*', True)], - 'output_files': [('clusters.metrics.pqt', f'alf/{self.pname}*', True)] - } - return signature - - def _run(self): - """ - Post spike-sorting quality control at the cluster level. - Outputs a QC table in the clusters ALF object and labels corresponding probes in Alyx - """ - files_spikes = Path(self.session_path).joinpath('alf', self.pname).rglob('spikes.times.npy') - folder_probes = [f.parent for f in files_spikes] - out_files = [] - for folder_probe in folder_probes: - try: - qc_file, df_units, drift = self.compute_cell_qc(folder_probe) - out_files.append(qc_file) - self._label_probe_qc(folder_probe, df_units, drift) - except Exception: - _logger.error(traceback.format_exc()) - self.status = -1 - continue - return out_files diff --git a/ibllib/pipes/tasks.py b/ibllib/pipes/tasks.py index 61125a635..14ead0fdd 100644 --- a/ibllib/pipes/tasks.py +++ b/ibllib/pipes/tasks.py @@ -114,7 +114,7 @@ class Task(abc.ABC): env = None # the environment name within which to run the task (NB: the env is not activated automatically!) def __init__(self, session_path, parents=None, taskid=None, one=None, - machine=None, clobber=True, location='server', **kwargs): + machine=None, clobber=True, location='server', scratch_folder=None, **kwargs): """ Base task class :param session_path: session path @@ -125,7 +125,8 @@ def __init__(self, session_path, parents=None, taskid=None, one=None, :param clobber: whether or not to overwrite log on rerun :param location: location where task is run. 
Options are 'server' (lab local servers'), 'remote' (remote compute node, data required for task downloaded via one), 'AWS' (remote compute node, data required for task downloaded via AWS), - or 'SDSC' (SDSC flatiron compute node) # TODO 'Globus' (remote compute node, data required for task downloaded via Globus) + or 'SDSC' (SDSC flatiron compute node) + :param scratch_folder: optional: Path where to write intermediate temporary data :param args: running arguments """ self.taskid = taskid @@ -141,6 +142,7 @@ def __init__(self, session_path, parents=None, taskid=None, one=None, self.clobber = clobber self.location = location self.plot_tasks = [] # Plotting task/ tasks to create plot outputs during the task + self.scratch_folder = scratch_folder self.kwargs = kwargs @property @@ -221,7 +223,7 @@ def run(self, **kwargs): if self.gpu >= 1: if not self._creates_lock(): self.status = -2 - _logger.info(f'Job {self.__class__} exited as a lock was found') + _logger.info(f'Job {self.__class__} exited as a lock was found at {self._lock_file_path()}') new_log = log_capture_string.getvalue() self.log = new_log if self.clobber else self.log + new_log _logger.removeHandler(ch) @@ -434,7 +436,7 @@ def assert_expected_outputs(self, raise_error=True): return everything_is_fine, files - def assert_expected_inputs(self, raise_error=True): + def assert_expected_inputs(self, raise_error=True, raise_ambiguous=False): """ Check that all the files necessary to run the task have been are present on disk. @@ -469,7 +471,7 @@ def assert_expected_inputs(self, raise_error=True): for k, v in variant_datasets.items() if any(v)} _logger.error('Ambiguous input datasets found: %s', ambiguous) - if raise_error or self.location == 'sdsc': # take no chances on SDSC + if raise_ambiguous or self.location == 'sdsc': # take no chances on SDSC # This could be mitigated if loading with data OneSDSC raise NotImplementedError( 'Multiple variant datasets found. 
Loading for these is undefined.') @@ -528,6 +530,8 @@ def get_data_handler(self, location=None): dhandler = data_handlers.SDSCDataHandler(self, self.session_path, self.signature, one=self.one) elif location == 'popeye': dhandler = data_handlers.PopeyeDataHandler(self, self.session_path, self.signature, one=self.one) + elif location == 'ec2': + dhandler = data_handlers.RemoteEC2DataHandler(self.session_path, self.signature, one=self.one) else: raise ValueError(f'Unknown location "{location}"') return dhandler diff --git a/ibllib/plots/figures.py b/ibllib/plots/figures.py index 384042add..14a6bb554 100644 --- a/ibllib/plots/figures.py +++ b/ibllib/plots/figures.py @@ -18,6 +18,7 @@ import one.alf.io as alfio from one.alf.exceptions import ALFObjectNotFound from ibllib.io.video import get_video_frame, url_from_eid +from ibllib.oneibl.data_handlers import ExpectedDataset import spikeglx import neuropixel from brainbox.plot import driftmap @@ -387,7 +388,6 @@ def get_probe_signature(self): def get_signatures(self, **kwargs): files_spikes = Path(self.session_path).joinpath('alf').rglob('spikes.times.npy') folder_probes = [f.parent for f in files_spikes] - full_input_files = [] for sig in self.signature['input_files']: for folder in folder_probes: @@ -396,8 +396,9 @@ def get_signatures(self, **kwargs): self.input_files = full_input_files else: self.input_files = self.signature['input_files'] - self.output_files = self.signature['output_files'] + self.input_files = [ExpectedDataset.input(*i) for i in self.input_files] + self.output_files = [ExpectedDataset.output(*i) for i in self.output_files] class BadChannelsAp(ReportSnapshotProbe): diff --git a/ibllib/tests/test_pipes.py b/ibllib/tests/test_pipes.py index 56ef51e68..9383d6dad 100644 --- a/ibllib/tests/test_pipes.py +++ b/ibllib/tests/test_pipes.py @@ -38,8 +38,8 @@ def test_task_queue(self, lab_repo_mock): lab_repo_mock.return_value = 'foo_repo' tasks = [ {'executable': 'ibllib.pipes.mesoscope_tasks.MesoscopePreprocess', 'priority': 80}, - {'executable': 'ibllib.pipes.ephys_tasks.SpikeSorting', 'priority': SpikeSorting.priority}, - {'executable': 'ibllib.pipes.base_tasks.RegisterRawDataTask', 'priority': RegisterRawDataTask.priority} + {'executable': 'ibllib.pipes.ephys_tasks.SpikeSorting', 'priority': SpikeSorting.priority}, # 60 + {'executable': 'ibllib.pipes.base_tasks.RegisterRawDataTask', 'priority': RegisterRawDataTask.priority} # 100 ] alyx = mock.Mock(spec=AlyxClient) alyx.rest.return_value = tasks @@ -49,10 +49,10 @@ def test_task_queue(self, lab_repo_mock): self.assertIn('foolab', alyx.rest.call_args.kwargs.get('django', '')) self.assertIn('foo_repo', alyx.rest.call_args.kwargs.get('django', '')) # Expect to return tasks in descending priority order, without mesoscope task (different env) - self.assertEqual([tasks[2], tasks[1]], queue) + self.assertEqual([tasks[2]], queue) # Expect only mesoscope task returned when relevant env passed - queue = local_server.task_queue(lab='foolab', alyx=alyx, env=('suite2p',)) - self.assertEqual([tasks[0]], queue) + queue = local_server.task_queue(lab='foolab', alyx=alyx, env=('suite2p', 'iblsorter')) + self.assertEqual([tasks[0], tasks[1]], queue) # Expect no tasks as mesoscope task is a large job queue = local_server.task_queue(mode='small', lab='foolab', alyx=alyx, env=('suite2p',)) self.assertEqual([], queue) diff --git a/requirements.txt b/requirements.txt index c6c7427e0..31cdb0898 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ iblatlas>=0.5.3 ibl-neuropixel>=1.0.1 
iblutil>=1.13.0 mtscomp>=1.0.1 -ONE-api~=2.9.rc0 +ONE-api>=2.10 phylib>=2.6.0 psychofit slidingRP>=1.1.1 # steinmetz lab refractory period metrics
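
A minimal usage sketch of the new spike-sorter collection fallback on the brainbox side, assuming a configured ONE instance; the probe insertion id below is a placeholder:

from one.api import ONE
from brainbox.io.one import SpikeSortingLoader

one = ONE()
pid = '00000000-0000-0000-0000-000000000000'  # placeholder probe insertion UUID
ssl = SpikeSortingLoader(pid=pid, one=one)
# with the new defaults (spike_sorter='iblsorter', enforce_version=False), the collection is resolved
# in order alf/<probe>/iblsorter, then alf/<probe>/pykilosort, then the shortest matching alf/<probe> collection
spikes, clusters, channels = ssl.load_spike_sorting()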
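
A sketch of running the reworked SpikeSorting task with the new location='ec2' handler and scratch_folder argument, assuming the keyword arguments used by make_pipeline (pname, device_collection); the session path and probe name are placeholders:

from pathlib import Path
from one.api import ONE
from ibllib.pipes.ephys_tasks import SpikeSorting

one = ONE()
session_path = Path('/mnt/s0/Data/Subjects/subject/2024-01-01/001')  # placeholder session
task = SpikeSorting(session_path, one=one, pname='probe00', device_collection='raw_ephys_data',
                    location='ec2',                  # selects RemoteEC2DataHandler: inputs via ONE, outputs registered through S3Patcher
                    scratch_folder=Path('/mnt/h0'))  # used by scratch_folder_run instead of parsing SCRATCH_DRIVE from the shell script
task.setUp()       # the data handler downloads the missing input datasets
status = task.run()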
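
A worked example of the ANSI-stripping pattern introduced in _fetch_iblsorter_run_version; the colourised version string below is made up for illustration:

import re

ansi_escape = r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])'  # same pattern as in the diff
line = '\x1b[36m1.9.0\x1b[0m'  # e.g. a colourised version captured from the sorter log
print(re.sub(ansi_escape, '', line))  # -> 1.9.0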