Skip to content

Commit

Permalink
Merge pull request #664 from int-brain-lab/uuidFilenames
Browse files Browse the repository at this point in the history
sdsc load data possible by monkey patching alfio in brainbox.io.one
  • Loading branch information
oliche authored Oct 25, 2023
2 parents 1712805 + c9850ec commit 94fdbda
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]
max-line-length = 130
ignore = W504, W503, E266
ignore = W504, W503, E266, D, BLK
exclude =
.git,
__pycache__,
Expand Down
28 changes: 21 additions & 7 deletions brainbox/io/one.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import matplotlib.pyplot as plt

from one.api import ONE, One
import one.alf.io as alfio
from one.alf.files import get_alf_path
from one.alf.exceptions import ALFObjectNotFound
from one.alf import cache
import one.alf.io as alfio
from neuropixel import TIP_SIZE_UM, trace_header
import spikeglx

Expand Down Expand Up @@ -830,6 +830,20 @@ def __post_init__(self):
self.atlas = AllenAtlas()
self.files = {}

def _load_object(self, *args, **kwargs):
"""
This function is a wrapper around alfio.load_object that will remove the UUID in the
filename if the object is on SDSC.
"""
remove_uuids = getattr(self.one, 'uuid_filenames', False)
d = alfio.load_object(*args, **kwargs)
if remove_uuids:
# pops the UUID in the key names
keys = list(d.keys())
for k in keys:
d[k[:-37]] = d.pop(k)
return d

@staticmethod
def _get_attributes(dataset_types):
"""returns attributes to load for spikes and clusters objects"""
Expand Down Expand Up @@ -865,7 +879,7 @@ def load_spike_sorting_object(self, obj, *args, **kwargs):
:return:
"""
self.download_spike_sorting_object(obj, *args, **kwargs)
return alfio.load_object(self.files[obj])
return self._load_object(self.files[obj])

def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None,
missing='raise', **kwargs):
Expand Down Expand Up @@ -922,10 +936,10 @@ def load_channels(self, **kwargs):
# we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore')
if 'electrodeSites' in self.files:
channels = alfio.load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
else: # otherwise, we try to load the channel object from the spike sorting folder - this may not contain histology
self.download_spike_sorting_object(obj='channels', **kwargs)
channels = alfio.load_object(self.files['channels'], wildcards=self.one.wildcards)
channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
if 'brainLocationIds_ccf_2017' not in channels:
_logger.debug(f"loading channels from alyx for {self.files['channels']}")
_channels, self.histology = _load_channel_locations_traj(
Expand Down Expand Up @@ -960,8 +974,8 @@ def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs):
self.spike_sorter = spike_sorter
self.download_spike_sorting(spike_sorter=spike_sorter, **kwargs)
channels = self.load_channels(spike_sorter=spike_sorter, **kwargs)
clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards)
spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards)
clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)
spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)

return spikes, clusters, channels

Expand Down Expand Up @@ -1090,7 +1104,7 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_

self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
if 'drift' in self.files:
drift = alfio.load_object(self.files['drift'], wildcards=self.one.wildcards)
drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)

if save_dir is not None:
Expand Down

0 comments on commit 94fdbda

Please sign in to comment.