Skip to content

Commit

Permalink
Merge branch 'develop' into TrainingTaskQC
Browse files Browse the repository at this point in the history
  • Loading branch information
k1o0 committed Dec 11, 2023
2 parents 06d65ea + e0cac72 commit c848cea
Show file tree
Hide file tree
Showing 19 changed files with 1,760 additions and 442 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]
max-line-length = 130
ignore = W504, W503, E266
ignore = W504, W503, E266, D, BLK
exclude =
.git,
__pycache__,
Expand Down
28 changes: 21 additions & 7 deletions brainbox/io/one.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import matplotlib.pyplot as plt

from one.api import ONE, One
import one.alf.io as alfio
from one.alf.files import get_alf_path
from one.alf.exceptions import ALFObjectNotFound
from one.alf import cache
import one.alf.io as alfio
from neuropixel import TIP_SIZE_UM, trace_header
import spikeglx

Expand Down Expand Up @@ -830,6 +830,20 @@ def __post_init__(self):
self.atlas = AllenAtlas()
self.files = {}

def _load_object(self, *args, **kwargs):
    """
    Load an ALF object with `alfio.load_object` and strip UUIDs from its keys.

    This is a thin wrapper around `alfio.load_object`. When the ONE instance
    (`self.one`) has a truthy `uuid_filenames` attribute (per the original
    note, the case when the object is on SDSC), every key that ends with a
    UUID has that UUID, and the single separator character preceding it,
    removed. Keys that do not end with a UUID are left unchanged instead of
    being blindly truncated by 37 characters.

    Parameters
    ----------
    *args, **kwargs
        Passed unchanged to `alfio.load_object`.

    Returns
    -------
    dict-like
        The object returned by `alfio.load_object`; its keys are renamed in
        place when UUID removal applies.
    """
    import re  # local import: keeps this method independent of the module header

    d = alfio.load_object(*args, **kwargs)
    if not getattr(self.one, 'uuid_filenames', False):
        return d
    # One separator character followed by a canonical 36-character UUID at the end of the key
    uuid_suffix = re.compile(
        r'.[0-9a-fA-F]{8}(?:-[0-9a-fA-F]{4}){3}-[0-9a-fA-F]{12}\Z', re.DOTALL)
    # Iterate over a snapshot of the keys as the mapping is modified in the loop
    for k in list(d.keys()):
        match = uuid_suffix.search(k)
        # Only strip when something remains in front of the suffix; otherwise keep the key
        new_key = k[:match.start()] if match and match.start() > 0 else k
        # Every key is popped and re-inserted so the resulting key order is unchanged
        d[new_key] = d.pop(k)
    return d

@staticmethod
def _get_attributes(dataset_types):
"""returns attributes to load for spikes and clusters objects"""
Expand Down Expand Up @@ -865,7 +879,7 @@ def load_spike_sorting_object(self, obj, *args, **kwargs):
:return:
"""
self.download_spike_sorting_object(obj, *args, **kwargs)
return alfio.load_object(self.files[obj])
return self._load_object(self.files[obj])

def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None,
missing='raise', **kwargs):
Expand Down Expand Up @@ -922,10 +936,10 @@ def load_channels(self, **kwargs):
# we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore')
if 'electrodeSites' in self.files:
channels = alfio.load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
else: # otherwise, we try to load the channel object from the spike sorting folder - this may not contain histology
self.download_spike_sorting_object(obj='channels', **kwargs)
channels = alfio.load_object(self.files['channels'], wildcards=self.one.wildcards)
channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
if 'brainLocationIds_ccf_2017' not in channels:
_logger.debug(f"loading channels from alyx for {self.files['channels']}")
_channels, self.histology = _load_channel_locations_traj(
Expand Down Expand Up @@ -960,8 +974,8 @@ def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs):
self.spike_sorter = spike_sorter
self.download_spike_sorting(spike_sorter=spike_sorter, **kwargs)
channels = self.load_channels(spike_sorter=spike_sorter, **kwargs)
clusters = alfio.load_object(self.files['clusters'], wildcards=self.one.wildcards)
spikes = alfio.load_object(self.files['spikes'], wildcards=self.one.wildcards)
clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)
spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)

return spikes, clusters, channels

Expand Down Expand Up @@ -1090,7 +1104,7 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_

self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
if 'drift' in self.files:
drift = alfio.load_object(self.files['drift'], wildcards=self.one.wildcards)
drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)

if save_dir is not None:
Expand Down
168 changes: 168 additions & 0 deletions examples/loading_data/loading_raw_audio_data.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5683982d",
"metadata": {},
"source": [
"# Loading Raw Audio Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b2485da",
"metadata": {
"nbsphinx": "hidden"
},
"outputs": [],
"source": [
"# Turn off logging, this is a hidden cell on docs page\n",
"import logging\n",
"logger = logging.getLogger('ibllib')\n",
"logger.setLevel(logging.CRITICAL)"
]
},
{
"cell_type": "markdown",
"id": "16345774",
"metadata": {},
"source": [
"The audio file is recorded from the microphone. Plotting its spectrogram is a useful way to confirm that the sounds played during the task are indeed audible."
]
},
{
"cell_type": "markdown",
"id": "8d62c890",
"metadata": {},
"source": [
"## Relevant datasets\n",
"* _iblrig_micData.raw.flac\n"
]
},
{
"cell_type": "markdown",
"id": "bc23fdf7",
"metadata": {},
"source": [
"## Loading"
]
},
{
"cell_type": "markdown",
"id": "9103084d",
"metadata": {},
"source": [
"### Loading raw audio file"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b807296",
"metadata": {
"ibl_execute": false
},
"outputs": [],
"source": [
"from one.api import ONE\n",
"import soundfile as sf\n",
"\n",
"one = ONE()\n",
"eid = '4ecb5d24-f5cc-402c-be28-9d0f7cb14b3a'\n",
"\n",
"# -- Get raw data\n",
"filename = one.load_dataset(eid, '_iblrig_micData.raw.flac', download_only=True)\n",
"with open(filename, 'rb') as f:\n",
" wav, fs = sf.read(f)"
]
},
{
"cell_type": "markdown",
"id": "203d23c1",
"metadata": {},
"source": [
"## Plot the spectrogram"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "811e3533",
"metadata": {
"ibl_execute": false
},
"outputs": [],
"source": [
"from ibllib.io.extractors.training_audio import welchogram\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# -- Compute spectrogram over first 2 minutes\n",
"t_idx = 120 * fs\n",
"tscale, fscale, W, detect = welchogram(fs, wav[:t_idx])\n",
"\n",
"# -- Put data into single variable\n",
"TF = {}\n",
"\n",
"TF['power'] = W.astype(np.single)\n",
"TF['frequencies'] = fscale[None, :].astype(np.single)\n",
"TF['onset_times'] = detect\n",
"TF['times_mic'] = tscale[:, None].astype(np.single)\n",
"\n",
"# # -- Plot spectrogram\n",
"tlims = TF['times_mic'][[0, -1]].flatten()\n",
"flims = TF['frequencies'][0, [0, -1]].flatten()\n",
"fig = plt.figure(figsize=[16, 7])\n",
"ax = plt.axes()\n",
"im = ax.imshow(20 * np.log10(TF['power'].T), aspect='auto', cmap=plt.get_cmap('magma'),\n",
" extent=np.concatenate((tlims, flims)),\n",
" origin='lower')\n",
"ax.set_xlabel(r'Time (s)')\n",
"ax.set_ylabel(r'Frequency (Hz)')\n",
"plt.colorbar(im)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "bef6702e",
"metadata": {},
"source": [
"## More details\n",
"* [Description of audio datasets](https://docs.google.com/document/d/1OqIqqakPakHXRAwceYLwFY9gOrm8_P62XIfCTnHwstg/edit#heading=h.n61f0vdcplxp)"
]
},
{
"cell_type": "markdown",
"id": "4e9dd4b9",
"metadata": {},
"source": [
"## Useful modules\n",
"* [ibllib.io.extractors.training_audio](https://int-brain-lab.github.io/iblenv/_autosummary/ibllib.io.extractors.training_audio.html#module-ibllib.io.extractors.training_audio)"
]
}
],
"metadata": {
"celltoolbar": "Edit Metadata",
"kernelspec": {
"display_name": "Python [conda env:iblenv] *",
"language": "python",
"name": "conda-env-iblenv-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
58 changes: 57 additions & 1 deletion examples/loading_data/loading_raw_ephys_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@
"\n",
"# Use spikeglx reader to read in the whole raw data\n",
"sr = spikeglx.Reader(bin_file)\n",
"sr.shape\n"
"print(sr.shape)"
]
},
{
Expand Down Expand Up @@ -326,6 +326,62 @@
"destriped = destripe(raw, fs=sr.fs)"
]
},
{
"cell_type": "markdown",
"source": [
"## Get the probe geometry"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Using the `eid` and `probe` information"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from brainbox.io.one import load_channel_locations\n",
"channels = load_channel_locations(eid, probe)\n",
"print(channels[probe].keys())\n",
"# Use the axial and lateral coordinates ; Print example first 4 channels\n",
"print(channels[probe][\"axial_um\"][0:4])\n",
"print(channels[probe][\"lateral_um\"][0:4])"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Using the reader and the `.cbin` file"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# You would have loaded the bin file as per the loading example above\n",
"# sr = spikeglx.Reader(bin_file)\n",
"sr.geometry"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"id": "9851b10d",
Expand Down
9 changes: 9 additions & 0 deletions ibllib/io/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,16 @@ class BaseExtractor(abc.ABC):
"""

session_path = None
"""pathlib.Path: Absolute path of session folder."""

save_names = None
"""tuple of str: The filenames of each extracted dataset, or None if array should not be saved."""

var_names = None
"""tuple of str: A list of names for the extracted variables. These become the returned output keys."""

default_path = Path('alf') # relative to session
"""pathlib.Path: The default output folder relative to `session_path`."""

def __init__(self, session_path=None):
# If session_path is None Path(session_path) will fail
Expand Down Expand Up @@ -127,6 +134,8 @@ class BaseBpodTrialsExtractor(BaseExtractor):
bpod_trials = None
settings = None
task_collection = None
frame2ttl = None
audio = None

def extract(self, bpod_trials=None, settings=None, **kwargs):
"""
Expand Down
2 changes: 2 additions & 0 deletions ibllib/io/extractors/biased_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ class EphysTrials(BaseBpodTrialsExtractor):
def _extract(self, extractor_classes=None, **kwargs) -> dict:
base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence]
# Get all detected TTLs. These are stored for QC purposes
self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
# Exclude from trials table
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
save=False, task_collection=self.task_collection)
Expand Down
6 changes: 5 additions & 1 deletion ibllib/io/extractors/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,12 +513,16 @@ def attribute_times(arr, events, tol=.1, injective=True, take='first'):
Returns
-------
numpy.array
An array the same length as `events`.
An array the same length as `events` containing indices of `arr` corresponding to each
event.
"""
if (take := take.lower()) not in ('first', 'nearest', 'after'):
raise ValueError('Parameter `take` must be either "first", "nearest", or "after"')
stack = np.ma.masked_invalid(arr, copy=False)
stack.fill_value = np.inf
# If there are no invalid values, the mask is False so let's ensure it's a bool array
if stack.mask is np.bool_(0):
stack.mask = np.zeros(arr.shape, dtype=bool)
assigned = np.full(events.shape, -1, dtype=int) # Initialize output array
min_tol = 0 if take == 'after' else -tol
for i, x in enumerate(events):
Expand Down
Loading

0 comments on commit c848cea

Please sign in to comment.