Iblsort (#811)
* remove deprecated spike sorting task

* flake

* add pointer to the lock file when the task fails

* spike sorting loader uses iblsorter as default collection

* set the waveform extraction chunk size to 30_000

* bugfix: reflect changes in arguments of waveform extraction

* fix reversed assertion

* change entrypoint for spike sorting script

* remove ephys cell qc task from pipeline as it is part of spike sorting

* add iblsort environment to the spike sorting task

* typo

* test forcing subprocess for large_jobs

* Revert "test forcing subprocess for large_jobs"

This reverts commit 31ff95d.

* label probe qc if ONE instance

* add passingSpikes.pqt to spike sorting job - update task signature

* configure task to decompress cbin beforehand

* spike sorting loader with empty sorter name

* revert decompressing ap cbin file before sorting

* s3 patcher

* Fix bug in the make sorting plots task signature

* add ec2 datahandler

* update iblsorter task signatures

* clobber=True

* add make remotehttp

* add an optional scratch_folder argument to the task

* bugfix Streamer / ONE 2.10.1: pass the eid

* SpikeSorting task: recover raw data qc files

* fix task queue test

* fix streamer with one 2.10

* spike sorting loading: relax the version until the BWM is patched

---------

Co-authored-by: chris-langfield <[email protected]>
Co-authored-by: Mayo Faulkner <[email protected]>
3 people authored Oct 28, 2024
1 parent cdf635d commit aedbdc8
Showing 11 changed files with 252 additions and 118 deletions.
21 changes: 14 additions & 7 deletions brainbox/io/one.py
@@ -866,13 +866,21 @@ def _get_attributes(dataset_types):
         waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes))
         return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes}
 
-    def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
+    def _get_spike_sorting_collection(self, spike_sorter=None):
         """
         Filters a list or array of collections to get the relevant spike sorting dataset
         if there is a pykilosort, load it
         """
-        collection = next(filter(lambda c: c == f'alf/{self.pname}/{spike_sorter}', self.collections), None)
-        # otherwise, prefers the shortest
+        for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']):
+            if sorter is None:
+                continue
+            if sorter == "":
+                collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None)
+            else:
+                collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None)
+            if collection is not None:
+                return collection
+        # if none is found amongst the defaults, prefers the shortest
         collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None)
         _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
         return collection
@@ -982,14 +990,13 @@ def download_raw_waveforms(self, **kwargs):
         """
         _logger.debug(f"loading waveforms from {self.collection}")
         return self.one.load_object(
-            self.eid, "waveforms",
-            attribute=["traces", "templates", "table", "channels"],
+            id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"],
             collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs
         )
 
     def raw_waveforms(self, **kwargs):
         wf_paths = self.download_raw_waveforms(**kwargs)
-        return WaveformsLoader(wf_paths[0].parent, wfs_dtype=np.float16)
+        return WaveformsLoader(wf_paths[0].parent)
 
     def load_channels(self, **kwargs):
         """
@@ -1022,7 +1029,7 @@ def load_channels(self, **kwargs):
             self.histology = 'alf'
         return Bunch(channels)
 
-    def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
+    def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, **kwargs):
         """
         Loads spikes, clusters and channels
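
For orientation, a minimal usage sketch of the loader after this change (not part of the commit; the probe insertion id is a made-up placeholder). With the new defaults, load_spike_sorting resolves the iblsorter collection first, then falls back to pykilosort, and no longer enforces the sorter version.

    # Usage sketch only -- the pid value is a made-up placeholder, not from this commit
    from one.api import ONE
    from brainbox.io.one import SpikeSortingLoader

    one = ONE()
    pid = 'xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx'  # placeholder probe insertion UUID
    ssl = SpikeSortingLoader(pid=pid, one=one)
    # with the defaults, the alf/<probe>/iblsorter collection is searched first,
    # then pykilosort, then the shortest matching collection
    spikes, clusters, channels = ssl.load_spike_sorting()
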
1 change: 1 addition & 0 deletions brainbox/io/spikeglx.py
@@ -128,6 +128,7 @@ def __init__(self, pid, one, typ='ap', cache_folder=None, remove_cached=False):
         self.file_chunks = self.one.load_dataset(self.eid, f'*.{typ}.ch', collection=f"*{self.pname}")
         meta_file = self.one.load_dataset(self.eid, f'*.{typ}.meta', collection=f"*{self.pname}")
         cbin_rec = self.one.list_datasets(self.eid, collection=f"*{self.pname}", filename=f'*{typ}.*bin', details=True)
+        cbin_rec.index = cbin_rec.index.map(lambda x: (self.eid, x))
         self.url_cbin = self.one.record2url(cbin_rec)[0]
         with open(self.file_chunks, 'r') as f:
             self.chunks = json.load(f)
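
The added line re-keys the dataset record by (eid, dataset id), presumably to match the index returned by list_datasets in ONE 2.10. A usage sketch of the streamer (not part of the commit; the pid is a made-up placeholder):

    # Usage sketch only -- the pid value is a made-up placeholder
    from one.api import ONE
    from brainbox.io.spikeglx import Streamer

    one = ONE()
    pid = 'xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx'  # placeholder probe insertion UUID
    sr = Streamer(pid=pid, one=one, typ='ap')
    # reads one second of raw AP-band data (30 kHz); chunks are streamed over http and cached locally
    raw = sr[10_000:40_000, :]
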
2 changes: 1 addition & 1 deletion ibllib/ephys/sync_probes.py
@@ -47,7 +47,7 @@ def sync(ses_path, **kwargs):
         return version3B(ses_path, **kwargs)
 
 
-def version3A(ses_path, display=True, type='smooth', tol=2.1):
+def version3A(ses_path, display=True, type='smooth', tol=2.1, probe_names=None):
     """
     From a session path with _spikeglx_sync arrays extracted, locate ephys files for 3A and
     outputs one sync.timestamps.probeN.npy file per acquired probe. By convention the reference
34 changes: 33 additions & 1 deletion ibllib/oneibl/data_handlers.py
@@ -21,7 +21,7 @@
 from iblutil.util import flatten, ensure_list
 
 from ibllib.oneibl.registration import register_dataset, get_lab, get_local_data_repository
-from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH
+from ibllib.oneibl.patcher import FTPPatcher, SDSCPatcher, SDSC_ROOT_PATH, SDSC_PATCH_PATH, S3Patcher
 
 
 _logger = logging.getLogger(__name__)
@@ -747,6 +747,38 @@ def cleanUp(self):
             os.unlink(file)
 
 
+class RemoteEC2DataHandler(DataHandler):
+    def __init__(self, session_path, signature, one=None):
+        """
+        Data handler for running tasks on remote compute node. Will download missing data via http using ONE
+        :param session_path: path to session
+        :param signature: input and output file signatures
+        :param one: ONE instance
+        """
+        super().__init__(session_path, signature, one=one)
+
+    def setUp(self):
+        """
+        Function to download necessary data to run tasks using ONE
+        :return:
+        """
+        df = super().getData()
+        self.one._check_filesystem(df)
+
+    def uploadData(self, outputs, version, **kwargs):
+        """
+        Function to upload and register data of completed task via S3 patcher
+        :param outputs: output files from task to register
+        :param version: ibllib version
+        :return: output info of registered datasets
+        """
+        versions = super().uploadData(outputs, version)
+        s3_patcher = S3Patcher(one=self.one)
+        return s3_patcher.patch_dataset(outputs, created_by=self.one.alyx.user,
+                                        versions=versions, **kwargs)
+
+
 class RemoteHttpDataHandler(DataHandler):
     def __init__(self, session_path, signature, one=None):
         """
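
A hedged sketch of exercising the new handler directly (not part of the commit; the session path and signature dict are made-up placeholders, and within the pipeline the handler is normally selected by the task from its location setting rather than instantiated by hand):

    # Usage sketch only -- session_path and signature are made-up placeholders
    from one.api import ONE
    from ibllib.oneibl.data_handlers import RemoteEC2DataHandler

    one = ONE()
    session_path = '/mnt/s0/Data/Subjects/KS005/2024-01-01/001'  # placeholder session path
    signature = {'input_files': [], 'output_files': []}          # placeholder task signature
    handler = RemoteEC2DataHandler(session_path, signature, one=one)
    handler.setUp()  # downloads any missing input datasets over http via ONE
    # ... run the task, collect `outputs` (list of produced file paths) ...
    # handler.uploadData(outputs, version)  # registers outputs through the S3 patcher
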
56 changes: 54 additions & 2 deletions ibllib/oneibl/patcher.py
@@ -34,13 +34,13 @@
 import globus_sdk
 import iblutil.io.params as iopar
 from iblutil.util import ensure_list
-from one.alf.files import get_session_path, add_uuid_string
+from one.alf.files import get_session_path, add_uuid_string, full_path_parts
 from one.alf.spec import is_uuid_string, is_uuid
 from one import params
 from one.webclient import AlyxClient
 from one.converters import path_from_dataset
 from one.remote import globus
-from one.remote.aws import url2uri
+from one.remote.aws import url2uri, get_s3_from_alyx
 
 from ibllib.oneibl.registration import register_dataset
 
@@ -633,3 +633,55 @@ def _scp(self, local_path, remote_path, dry=True):
     def _rm(self, flatiron_path, dry=True):
         raise PermissionError("This Patcher does not have admin permissions to remove data "
                               "from the FlatIron server")
+
+
+class S3Patcher(Patcher):
+
+    def __init__(self, one=None):
+        assert one
+        super().__init__(one=one)
+        self.s3_repo = 's3_patcher'
+        self.s3_path = 'patcher'
+
+        # Instantiate boto connection
+        self.s3, self.bucket = get_s3_from_alyx(self.one.alyx, repo_name=self.s3_repo)
+
+    def check_datasets(self, file_list):
+        # Here we want to check if the datasets exist, if they do we don't want to patch unless we force.
+        exists = []
+        for file in file_list:
+            collection = full_path_parts(file, as_dict=True)['collection']
+            dset = self.one.alyx.rest('datasets', 'list', session=self.one.path2eid(file), name=file.name,
+                                      collection=collection, clobber=True)
+            if len(dset) > 0:
+                exists.append(file)
+
+        return exists
+
+    def patch_dataset(self, file_list, dry=False, ftp=False, force=False, **kwargs):
+
+        exists = self.check_datasets(file_list)
+        if len(exists) > 0 and not force:
+            _logger.error(f'Files: {", ".join([f.name for f in file_list])} already exist, to force set force=True')
+            return
+
+        response = super().patch_dataset(file_list, dry=dry, repository=self.s3_repo, ftp=False, **kwargs)
+        # TODO in an ideal case the flatiron filerecord won't be altered when we register this dataset. This requires
+        # changing the the alyx.data.register_view
+        for ds in response:
+            frs = ds['file_records']
+            fr_server = next(filter(lambda fr: 'flatiron' in fr['data_repository'], frs))
+            # Update the flatiron file record to be false
+            self.one.alyx.rest('files', 'partial_update', id=fr_server['id'],
+                               data={'exists': False})
+
+    def _scp(self, local_path, remote_path, dry=True):
+
+        aws_remote_path = Path(self.s3_path).joinpath(remote_path.relative_to(FLATIRON_MOUNT))
+        _logger.info(f'Transferring file {local_path} to {aws_remote_path}')
+        self.s3.Bucket(self.bucket).upload_file(str(PurePosixPath(local_path)), str(PurePosixPath(aws_remote_path)))
+
+        return 0, ''
+
+    def _rm(self, *args, **kwargs):
+        raise PermissionError("This Patcher does not have admin permissions to remove data.")
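
A hedged usage sketch of the new patcher (not part of the commit; the dataset path is a made-up placeholder and the 's3_patcher' repository is assumed to be configured on Alyx, as the class above expects):

    # Usage sketch only -- the dataset path is a made-up placeholder
    from pathlib import Path
    from one.api import ONE
    from ibllib.oneibl.patcher import S3Patcher

    one = ONE()
    file_list = [Path('/mnt/s0/Data/Subjects/KS005/2024-01-01/001/alf/probe00/iblsorter/spikes.times.npy')]
    patcher = S3Patcher(one=one)
    # uploads to the 's3_patcher' repository and registers the datasets on Alyx;
    # datasets that already exist are skipped unless force=True
    patcher.patch_dataset(file_list, created_by=one.alyx.user, force=False)
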
2 changes: 0 additions & 2 deletions ibllib/pipes/dynamic_pipeline.py
@@ -490,8 +490,6 @@ def make_pipeline(session_path, **pkwargs):
 
             tasks[f'RawEphysQC_{pname}'] = type(f'RawEphysQC_{pname}', (etasks.RawEphysQC,), {})(
                 **kwargs, **ephys_kwargs, pname=pname, parents=register_task)
-            tasks[f'EphysCellQC_{pname}'] = type(f'EphysCellQC_{pname}', (etasks.EphysCellsQc,), {})(
-                **kwargs, **ephys_kwargs, pname=pname, parents=[tasks[f'Spikesorting_{pname}']])
 
         # Video tasks
         if 'cameras' in devices:
(The remaining 5 changed files are not shown in this view.)
