Skip to content

Commit

Permalink
Extract waveform cbin outputs file list (#54)
Browse files Browse the repository at this point in the history
* the cbin waveform extractor returns the list of output files

* move to ruff formating

* update requirements and github workflow
  • Loading branch information
oliche authored Jan 3, 2025
1 parent 2af823b commit 6b5d488
Show file tree
Hide file tree
Showing 23 changed files with 616 additions and 302 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ jobs:

- name: flake8
run: |
pip install flake8>=7 --quiet
pip install ruff --quiet
cd ibl-neuropixel
python -m flake8
ruff check
- name: iblrig and iblpybpod requirements
shell: bash -l {0}
Expand Down
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,16 @@ The following describes the methods implemented in this repository.
https://doi.org/10.6084/m9.figshare.19705522

## Contribution
Contribution checklist:
- run tests
- ruff format
- PR to main


Pypi Release checklist:
- Edit the version number in `setup.py`, and add release notes in `release_notes.md`
- Edit the version number in `setup.py`
- add release notes in `release_notes.md`


```shell
flake8
Expand Down
9 changes: 9 additions & 0 deletions release_notes.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Changelog

## [1.6.1] - 2025-01-03

### fixed
- waveforms extractor returns a list of files for registration

### changed
- moved the repo contribution guide to automatic ruff formatting


## [1.6.0] - 2024-12-06

### added
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
flake8
pytest
ruff
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setuptools.setup(
name="ibl-neuropixel",
version="1.6.0",
version="1.6.1",
author="The International Brain Laboratory",
description="Collection of tools for Neuropixel 1.0 and 2.0 probes data",
long_description=long_description,
Expand Down
14 changes: 8 additions & 6 deletions src/ibldsp/cadzow.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,17 @@ def cadzow_np1(
WAV = scipy.fft.rfft(wav[:, :])
padgain = scipy.signal.windows.hann(npad * 2)[:npad]
WAV = np.r_[
np.flipud(WAV[1: npad + 1, :]) * padgain[:, np.newaxis],
np.flipud(WAV[1 : npad + 1, :]) * padgain[:, np.newaxis],
WAV,
np.flipud(WAV[-npad - 2: -1, :]) * np.flipud(np.r_[padgain, 1])[:, np.newaxis],
np.flipud(WAV[-npad - 2 : -1, :]) * np.flipud(np.r_[padgain, 1])[:, np.newaxis],
] # apply padding
x = np.r_[np.flipud(h["x"][1: npad + 1]), h["x"], np.flipud(h["x"][-npad - 2: -1])]
x = np.r_[
np.flipud(h["x"][1 : npad + 1]), h["x"], np.flipud(h["x"][-npad - 2 : -1])
]
y = np.r_[
np.flipud(h["y"][1: npad + 1]) - 120,
np.flipud(h["y"][1 : npad + 1]) - 120,
h["y"],
np.flipud(h["y"][-npad - 2: -1]) + 120,
np.flipud(h["y"][-npad - 2 : -1]) + 120,
]
WAV_ = np.zeros_like(WAV)
gain = np.zeros(ntr + npad * 2 + 1)
Expand All @@ -166,6 +168,6 @@ def cadzow_np1(
)
WAV_[firstx:lastx, :] += array * gw[:, np.newaxis]

WAV_ = WAV_[npad: -npad - 1] # remove padding
WAV_ = WAV_[npad : -npad - 1] # remove padding
wav_ = scipy.fft.irfft(WAV_)
return wav_
1 change: 1 addition & 0 deletions src/ibldsp/fourier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Low-level functions to work in frequency domain for n-dim arrays
"""

from math import pi

import numpy as np
Expand Down
58 changes: 39 additions & 19 deletions src/ibldsp/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import matplotlib.pyplot as plt


def show_channels_labels(raw, fs, channel_labels, xfeats, similarity_threshold, psd_hf_threshold=0.02):
def show_channels_labels(
raw, fs, channel_labels, xfeats, similarity_threshold, psd_hf_threshold=0.02
):
"""
Shows the features side by side a snippet of raw data
:param sr:
Expand All @@ -12,27 +14,45 @@ def show_channels_labels(raw, fs, channel_labels, xfeats, similarity_threshold,
raw = raw - np.mean(raw, axis=-1)[:, np.newaxis] # removes DC offset
ns_plot = np.minimum(ns, 3000)
vaxis_uv = 250 if fs < 2600 else 75
fig, ax = plt.subplots(1, 5, figsize=(18, 6), gridspec_kw={'width_ratios': [1, 1, 1, 8, .2]})
ax[0].plot(xfeats['xcor_hf'], np.arange(nc))
ax[0].plot(xfeats['xcor_hf'][(iko := channel_labels == 1)], np.arange(nc)[iko], 'k*')
ax[0].plot(similarity_threshold[0] * np.ones(2), [0, nc], 'k--')
ax[0].plot(similarity_threshold[1] * np.ones(2), [0, nc], 'r--')
ax[0].set(ylabel='channel #', xlabel='high coherence', ylim=[0, nc], title='a) dead channel')
ax[1].plot(xfeats['psd_hf'], np.arange(nc))
ax[1].plot(xfeats['psd_hf'][(iko := channel_labels == 2)], np.arange(nc)[iko], 'r*')
ax[1].plot(psd_hf_threshold * np.array([1, 1]), [0, nc], 'r--')
ax[1].set(yticklabels=[], xlabel='PSD', ylim=[0, nc], title='b) noisy channel')
fig, ax = plt.subplots(
1, 5, figsize=(18, 6), gridspec_kw={"width_ratios": [1, 1, 1, 8, 0.2]}
)
ax[0].plot(xfeats["xcor_hf"], np.arange(nc))
ax[0].plot(
xfeats["xcor_hf"][(iko := channel_labels == 1)], np.arange(nc)[iko], "k*"
)
ax[0].plot(similarity_threshold[0] * np.ones(2), [0, nc], "k--")
ax[0].plot(similarity_threshold[1] * np.ones(2), [0, nc], "r--")
ax[0].set(
ylabel="channel #",
xlabel="high coherence",
ylim=[0, nc],
title="a) dead channel",
)
ax[1].plot(xfeats["psd_hf"], np.arange(nc))
ax[1].plot(xfeats["psd_hf"][(iko := channel_labels == 2)], np.arange(nc)[iko], "r*")
ax[1].plot(psd_hf_threshold * np.array([1, 1]), [0, nc], "r--")
ax[1].set(yticklabels=[], xlabel="PSD", ylim=[0, nc], title="b) noisy channel")
ax[1].sharey(ax[0])
ax[2].plot(xfeats['xcor_lf'], np.arange(nc))
ax[2].plot(xfeats['xcor_lf'][(iko := channel_labels == 3)], np.arange(nc)[iko], 'y*')
ax[2].plot([-.75, -.75], [0, nc], 'y--')
ax[2].set(yticklabels=[], xlabel='LF coherence', ylim=[0, nc], title='c) outside')
ax[2].plot(xfeats["xcor_lf"], np.arange(nc))
ax[2].plot(
xfeats["xcor_lf"][(iko := channel_labels == 3)], np.arange(nc)[iko], "y*"
)
ax[2].plot([-0.75, -0.75], [0, nc], "y--")
ax[2].set(yticklabels=[], xlabel="LF coherence", ylim=[0, nc], title="c) outside")
ax[2].sharey(ax[0])
im = ax[3].imshow(raw[:, :ns_plot] * 1e6, origin='lower', cmap='PuOr', aspect='auto',
vmin=-vaxis_uv, vmax=vaxis_uv, extent=[0, ns_plot / fs * 1e3, 0, nc])
ax[3].set(yticklabels=[], title='d) Raw data', xlabel='time (ms)', ylim=[0, nc])
im = ax[3].imshow(
raw[:, :ns_plot] * 1e6,
origin="lower",
cmap="PuOr",
aspect="auto",
vmin=-vaxis_uv,
vmax=vaxis_uv,
extent=[0, ns_plot / fs * 1e3, 0, nc],
)
ax[3].set(yticklabels=[], title="d) Raw data", xlabel="time (ms)", ylim=[0, nc])
ax[3].grid(False)
ax[3].sharey(ax[0])
plt.colorbar(im, cax=ax[4], shrink=0.8).ax.set(ylabel='(uV)')
plt.colorbar(im, cax=ax[4], shrink=0.8).ax.set(ylabel="(uV)")
fig.tight_layout()
return fig, ax
4 changes: 3 additions & 1 deletion src/ibldsp/raw_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def compute_raw_features_snippet(sr_ap, sr_lf, t0, t1, filter_ap=None, filter_lf
dc_offset = np.mean(raw, axis=1)
channel_labels, xfeats_raw = detect_bad_channels(raw, **detect_kwargs[band])
butter = scipy.signal.sosfiltfilt(filters[band], raw)
destriped = destripe_fn[band](raw, fs=sr.fs, h=sr.geometry, channel_labels=channel_labels)
destriped = destripe_fn[band](
raw, fs=sr.fs, h=sr.geometry, channel_labels=channel_labels
)
# compute same channel feats for destripe
_, xfeats_destriped = detect_bad_channels(destriped, **detect_kwargs[band])

Expand Down
4 changes: 2 additions & 2 deletions src/ibldsp/smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def rolling_window(x, window_len=11, window="blackman"):
'bartlett', 'blackman'"
)

s = np.r_[x[window_len - 1: 0: -1], x, x[-1:-window_len:-1]]
s = np.r_[x[window_len - 1 : 0 : -1], x, x[-1:-window_len:-1]]
# print(len(s))
if window == "flat": # moving average
w = np.ones(window_len, "d")
else:
w = eval("np." + window + "(window_len)")

y = np.convolve(w / w.sum(), s, mode="valid")
return y[round((window_len / 2 - 1)): round(-(window_len / 2))]
return y[round((window_len / 2 - 1)) : round(-(window_len / 2))]


def non_uniform_savgol(x, y, window, polynom):
Expand Down
9 changes: 6 additions & 3 deletions src/ibldsp/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Window generator, front detections, rms
"""

import numpy as np
import scipy

Expand Down Expand Up @@ -276,13 +277,15 @@ def firstlast_splicing(self):
:return: tuple of (first_index, last_index, amplitude_vector]
"""
w = scipy.signal.windows.hann((self.overlap + 1) * 2 + 1, sym=True)[1:self.overlap + 1]
w = scipy.signal.windows.hann((self.overlap + 1) * 2 + 1, sym=True)[
1 : self.overlap + 1
]
assert np.all(np.isclose(w + np.flipud(w), 1))

for first, last in self.firstlast:
amp = np.ones(last - first)
amp[:self.overlap] = 1 if first == 0 else w
amp[-self.overlap:] = 1 if last == self.ns else np.flipud(w)
amp[: self.overlap] = 1 if first == 0 else w
amp[-self.overlap :] = 1 if last == self.ns else np.flipud(w)
yield (first, last, amp)

@property
Expand Down
71 changes: 53 additions & 18 deletions src/ibldsp/voltage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Module to work with raw voltage traces. Spike sorting pre-processing functions.
"""

from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -142,7 +143,7 @@ def fk(
return xf * gain


def car(x, collection=None, operator='median', **kwargs):
def car(x, collection=None, operator="median", **kwargs):
"""
Applies common average referencing with optional automatic gain control
:param x: np.array(nc, ns) the input array to be de-referenced. dimension, the filtering is considering
Expand All @@ -159,9 +160,9 @@ def car(x, collection=None, operator='median', **kwargs):
xout[sel, :] = car(x=x[sel, :], collection=None, **kwargs)
return xout

if operator == 'median':
if operator == "median":
x = x - np.median(x, axis=0)
elif operator == 'average':
elif operator == "average":
x = x - np.mean(x, axis=0)
return x

Expand Down Expand Up @@ -235,7 +236,9 @@ def kfilt(
return xf * gain


def saturation(data, max_voltage, v_per_sec=1e-8, fs=30_000, proportion=0.2, mute_window_samples=7):
def saturation(
data, max_voltage, v_per_sec=1e-8, fs=30_000, proportion=0.2, mute_window_samples=7
):
"""
Computes
:param data: [nc, ns]: voltage traces array
Expand All @@ -258,7 +261,7 @@ def saturation(data, max_voltage, v_per_sec=1e-8, fs=30_000, proportion=0.2, mut
saturation = np.logical_or(saturation > proportion, n_diff_saturated > proportion)
# apply a cosine taper to the saturation to create a mute function
win = scipy.signal.windows.cosine(mute_window_samples)
mute = np.maximum(0, 1 - scipy.signal.convolve(saturation, win, mode='same'))
mute = np.maximum(0, 1 - scipy.signal.convolve(saturation, win, mode="same"))
return saturation, mute


Expand Down Expand Up @@ -378,18 +381,31 @@ def destripe(
return x


def destripe_lfp(x, fs, h=None, channel_labels=None, butter_kwargs=None, k_filter=False):
def destripe_lfp(
x, fs, h=None, channel_labels=None, butter_kwargs=None, k_filter=False
):
"""
Wrapper around the destripe function with some default parameters to destripe the LFP band
See help destripe function for documentation
:param x: demultiplexed array (nc, ns)
:param fs: sampling frequency
:param channel_labels: see destripe
"""
butter_kwargs = {"N": 3, "Wn": [0.5, 300], "btype": "bandpass", "fs": fs} if butter_kwargs is None else butter_kwargs
butter_kwargs = (
{"N": 3, "Wn": [0.5, 300], "btype": "bandpass", "fs": fs}
if butter_kwargs is None
else butter_kwargs
)
if channel_labels is True:
channel_labels, _ = detect_bad_channels(x, fs=fs, psd_hf_threshold=1.4)
return destripe(x, fs, h=h, butter_kwargs=butter_kwargs, k_filter=k_filter, channel_labels=channel_labels)
return destripe(
x,
fs,
h=h,
butter_kwargs=butter_kwargs,
k_filter=k_filter,
channel_labels=channel_labels,
)


def decompress_destripe_cbin(
Expand Down Expand Up @@ -474,7 +490,9 @@ def decompress_destripe_cbin(
# if we want to compute the rms ap across the session as well as the saturation
if compute_rms:
# creates a saturation memmap, this is a nsamples vector of booleans
file_saturation = output_file.parent.joinpath("_iblqc_ephysSaturation.samples.npy")
file_saturation = output_file.parent.joinpath(
"_iblqc_ephysSaturation.samples.npy"
)
np.save(file_saturation, np.zeros(sr.ns, dtype=bool))
# creates the place holders for the rms
ap_rms_file = output_file.parent.joinpath("ap_rms.bin")
Expand Down Expand Up @@ -542,7 +560,8 @@ def my_function(i_chunk, n_chunk):
# Apply tapers
chunk = _sr[first_s:last_s, :ncv].T
saturated_samples, mute_saturation = saturation(
data=chunk, max_voltage=_sr.range_volts[:ncv], fs=_sr.fs)
data=chunk, max_voltage=_sr.range_volts[:ncv], fs=_sr.fs
)
_saturation[first_s:last_s] = saturated_samples
chunk[:, :SAMPLES_TAPER] *= taper[:SAMPLES_TAPER]
chunk[:, -SAMPLES_TAPER:] *= taper[SAMPLES_TAPER:]
Expand Down Expand Up @@ -617,13 +636,22 @@ def my_function(i_chunk, n_chunk):
saturation_data = np.load(file_saturation)
assert rms_data.shape[0] == time_data.shape[0] * ncv
rms_data = rms_data.reshape(time_data.shape[0], ncv)
output_qc_path = output_file.parent if output_qc_path is None else output_qc_path
output_qc_path = (
output_file.parent if output_qc_path is None else output_qc_path
)
np.save(output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.rms.npy"), rms_data)
np.save(output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.timestamps.npy"), time_data)
np.save(output_qc_path.joinpath("_iblqc_ephysSaturation.samples.npy"), saturation_data)
np.save(
output_qc_path.joinpath("_iblqc_ephysTimeRmsAP.timestamps.npy"), time_data
)
np.save(
output_qc_path.joinpath("_iblqc_ephysSaturation.samples.npy"),
saturation_data,
)


def detect_bad_channels(raw, fs, similarity_threshold=(-0.5, 1), psd_hf_threshold=None, display=False):
def detect_bad_channels(
raw, fs, similarity_threshold=(-0.5, 1), psd_hf_threshold=None, display=False
):
"""
Bad channels detection for Neuropixel probes
Labels channels
Expand Down Expand Up @@ -699,12 +727,12 @@ def nxcor(x, ref):
xcor = channels_similarity(raw)
fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs) # units; uV ** 2 / Hz
# auto-detection of the band with which we are working
band = 'ap' if fs > 2600 else 'lf'
band = "ap" if fs > 2600 else "lf"
# the LFP band data is obviously much stronger so auto-adjust the default threshold
if band == 'ap':
if band == "ap":
psd_hf_threshold = 0.02 if psd_hf_threshold is None else psd_hf_threshold
filter_kwargs = {"N": 3, "Wn": 300 / fs * 2, "btype": "highpass"}
elif band == 'lf':
elif band == "lf":
psd_hf_threshold = 1.4 if psd_hf_threshold is None else psd_hf_threshold
filter_kwargs = {"N": 3, "Wn": 1 / fs * 2, "btype": "highpass"}
sos_hp = scipy.signal.butter(**filter_kwargs, output="sos")
Expand Down Expand Up @@ -741,7 +769,13 @@ def nxcor(x, ref):
# ephys_bad_channels(x, 30000, ichannels, xfeats)
if display:
ibldsp.plots.show_channels_labels(
raw, fs, ichannels, xfeats, similarity_threshold=similarity_threshold, psd_hf_threshold=psd_hf_threshold)
raw,
fs,
ichannels,
xfeats,
similarity_threshold=similarity_threshold,
psd_hf_threshold=psd_hf_threshold,
)
return ichannels, xfeats


Expand Down Expand Up @@ -774,6 +808,7 @@ def detect_bad_channels_cbin(bin_file, n_batches=10, batch_duration=0.3, display
if display:
raw = sr[sl, :nc].TO
from ibllib.plots.figures import ephys_bad_channels

ephys_bad_channels(raw, sr.fs, channel_flags, xfeats_med)
return channel_flags

Expand Down
Loading

0 comments on commit 6b5d488

Please sign in to comment.