From 97c482494eb0230405e45787040579daf1bc3564 Mon Sep 17 00:00:00 2001
From: Olivier Winter
Date: Sat, 5 Oct 2024 11:49:54 +0100
Subject: [PATCH] Optim waveforms (#44)

* modify wiggle plot to allow filling of positive and/or negative peaks
* double wiggle: double trouble. And tests.
* pre-processing should have time samples contiguous in memory
* WIP compress before
* instantiate memmap only once
* waveform extraction: remove AGC for the CAR option
* Write to output is parallelized
* add decompression in waveform extraction
* waveform extract: add removal of temporary file
* flatten waveforms
* allow for cluster indices not starting at 0 and flake8 fixes
* starting on test w/ new format
* wip waveforms loader version 1 and 2
* load w/ indices
* some more fixes
* add colours to the channel detection plots
* remove the Windows tests from CI
* fix bug when cluster ids do not start at 0
* fix last set of tests for the waveform extraction
* remove neurodsp after grace period expired
* Make sure the waveform loader is compatible with the notebook
* add bad channel plot helper and default LF parameters
* add some splitting / splicing functions on window generator
* add lfp pre-processing
* fix channel labels bug
* fix default lfp filter parameters
* fix waveform extractor conflict
* add compatible syntax for 3.9

---------

Co-authored-by: chris-langfield
---
 .DS_Store                            | Bin 6148 -> 0 bytes
 .github/workflows/tests.yaml         |   2 +-
 release_notes.md                     |  10 +
 setup.py                             |   2 +-
 src/.DS_Store                        | Bin 6148 -> 0 bytes
 src/ibldsp/plots.py                  |  23 +-
 src/ibldsp/utils.py                  |  34 ++-
 src/ibldsp/voltage.py                |  79 +++----
 src/ibldsp/waveform_extraction.py    | 310 +++++++++++++--------------
 src/ibldsp/waveforms.py              |  55 ++++-
 src/neurodsp/__init__.py             |   8 -
 src/spikeglx.py                      |  25 ++-
 src/tests/unit/cpu/test_ibldsp.py    |  38 ++--
 src/tests/unit/cpu/test_waveforms.py |  49 +++--
 14 files changed, 349 insertions(+), 286 deletions(-)
 delete mode 100644 .DS_Store
 delete mode 100644 src/.DS_Store
 delete mode 100644 src/neurodsp/__init__.py

diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index 2c8855d5a1f99f4069383e14b127c58da39f85a6..0000000000000000000000000000000000000000
GIT binary patch
(binary blob omitted)
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 6da4e34..12e3474 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -13,7 +13,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: ["ubuntu-latest", "windows-latest"]
+        os: ["ubuntu-latest"]
         python-version: ["3.9", "3.10"]
     steps:
       - name: Checkout ibl-neuropixel repo
diff --git a/release_notes.md b/release_notes.md
index f03c5d6..1a802e4 100644
--- a/release_notes.md
+++ b/release_notes.md
@@ -1,3 +1,13 @@
+# 1.4
+
+## 1.4.0 2024-10-05
+- Waveform extraction:
+  - Optimization of the waveform extractor, outputs flattened waveforms
+  - Refactoring of the waveform loader with backward compatibility
+- Bad channel detector:
+  - The bad channel detector has a plot option to visualize the bad channels and thresholds
+  - The default low-cut filters are set to 300 Hz for the AP band and 2 Hz for the LF band
+
 # 1.3

 ## 1.3.2 2024-09-18
diff --git a/setup.py b/setup.py
index 70a90ed..34daece 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 setuptools.setup(
     name="ibl-neuropixel",
-    version="1.3.2",
+    version="1.4.0",
     author="The International Brain Laboratory",
     description="Collection of tools for Neuropixel 1.0 and 2.0 probes data",
     long_description=long_description,
diff --git a/src/.DS_Store b/src/.DS_Store
deleted file mode 100644
index b21c9d1154d1db3eb1a115ac05783c40f581eeb3..0000000000000000000000000000000000000000
GIT binary patch
(binary blob omitted)
diff --git a/src/ibldsp/voltage.py b/src/ibldsp/voltage.py
(hunks for src/ibldsp/plots.py, src/ibldsp/utils.py and the start of this file not recovered)
@@ ... @@ def detect_bad_channels(
+    band = 'ap' if fs > 2600 else 'lf'
+    # the LFP band data is obviously much stronger, so auto-adjust the default threshold
+    if band == 'ap':
+        psd_hf_threshold = 0.02 if psd_hf_threshold is None else psd_hf_threshold
+        filter_kwargs = {"N": 3, "Wn": 300 / fs * 2, "btype": "highpass"}
+    elif band == 'lf':
+        psd_hf_threshold = 1.4 if psd_hf_threshold is None else psd_hf_threshold
+        filter_kwargs = {"N": 3, "Wn": 1 / fs * 2, "btype": "highpass"}
+    sos_hp = scipy.signal.butter(**filter_kwargs, output="sos")
     hf = scipy.signal.sosfiltfilt(sos_hp, raw)
     xcorf = channels_similarity(hf)
-
     xfeats = {
         "ind": np.arange(nc),
         "rms_raw": utils.rms(raw),  # very similar to the rms after the butterworth filter
@@ -754,7 +740,8 @@ def nxcor(x, ref):
     # from ibllib.plots.figures import ephys_bad_channels
     # ephys_bad_channels(x, 30000, ichannels, xfeats)
     if display:
-        ibldsp.plots.show_channels_labels(raw, fs, ichannels, xfeats)
+        ibldsp.plots.show_channels_labels(
+            raw, fs, ichannels, xfeats, similarity_threshold=similarity_threshold, psd_hf_threshold=psd_hf_threshold)
     return ichannels, xfeats
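For reference, a minimal usage sketch of the updated detector. The call assumes the signature visible in this hunk (`detect_bad_channels(raw, fs, psd_hf_threshold=None, display=False)` returning `ichannels, xfeats`); the synthetic array is a stand-in for real data:

```python
import numpy as np
from ibldsp.voltage import detect_bad_channels

fs = 2500  # fs <= 2600 selects the 'lf' defaults (1 Hz high-pass, psd_hf_threshold=1.4)
raw = np.random.randn(384, int(fs * 2)) * 1e-6  # (nc, ns) snippet, channels first

# display=True now also draws the similarity and PSD thresholds used for the decision
channel_labels, xfeats = detect_bad_channels(raw, fs, display=True)
```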
diff --git a/src/ibldsp/waveform_extraction.py b/src/ibldsp/waveform_extraction.py
index 9075cb0..5041480 100644
--- a/src/ibldsp/waveform_extraction.py
+++ b/src/ibldsp/waveform_extraction.py
@@ -1,9 +1,9 @@
 import logging
+from pathlib import Path

 import scipy
 import pandas as pd
 import numpy as np
-from pathlib import Path
 from numpy.lib.format import open_memmap
 from joblib import Parallel, delayed, cpu_count
@@ -11,10 +11,25 @@
 from ibldsp.voltage import detect_bad_channels, interpolate_bad_channels, car, kfilt
 from ibldsp.fourier import fshift
 from ibldsp.utils import make_channel_index
+from iblutil.numerical import ismember

 logger = logging.getLogger(__name__)


+def aggregate_by_clusters(df_wavs):
+    """
+    Groups the waveform dataframe by cluster.
+    :param df_wavs: waveform dataframe, one row per waveform
+    :return: dataframe aggregated by cluster, with columns count, first_index, last_index
+    """
+    df_clusters = df_wavs.loc[df_wavs['sample'] >= 0, :].groupby('cluster').aggregate(
+        count=pd.NamedAgg(column="cluster", aggfunc="count"),
+        first_index=pd.NamedAgg(column="waveform_index", aggfunc="min"),
+        last_index=pd.NamedAgg(column="waveform_index", aggfunc="max"),
+    )
+    return df_clusters
+
+
 def extract_wfs_array(
     arr,
     df,
@@ -41,15 +56,15 @@ def extract_wfs_array(
     """
     # This is to do fast index assignment to assign missing channels (out of the probe) to NaN
     if add_nan_trace:
-        newcol = np.empty((arr.shape[0], 1))
+        newcol = np.empty((1, arr.shape[1]))
         newcol[:] = np.nan
-        arr = np.hstack([arr, newcol])
+        arr = np.vstack([arr, newcol])

     # check that the spike window is included in the recording:
     last_idx = df["sample"].iloc[-1]
     assert (
-        last_idx + (spike_length_samples - trough_offset) < arr.shape[0]
-    ), f"Spike index {last_idx} extends past end of recording ({arr.shape[0]} samples)."
+        last_idx + (spike_length_samples - trough_offset) < arr.shape[1]
+    ), f"Spike index {last_idx} extends past end of recording ({arr.shape[1]} samples)."

     nwf = len(df)
@@ -62,7 +77,7 @@ def extract_wfs_array(
     )
     nchan = cind.shape[1]

-    wfs = np.zeros((nwf, spike_length_samples, nchan), arr.dtype)
+    wfs = np.zeros((nwf, nchan, spike_length_samples), arr.dtype)
     fun = range
     if verbose:
         try:
@@ -72,9 +87,9 @@ def extract_wfs_array(
         except ImportError:
             pass
     for i in fun(nwf):
-        wfs[i, :, :] = arr[sind[i], :][:, cind[i]]
+        wfs[i, :, :] = arr[:, sind[i]][cind[i], :]

-    return wfs.swapaxes(1, 2), cind, trough_offset
+    return wfs, cind, trough_offset


 def _get_channel_labels(sr, num_snippets=20, verbose=True):
@@ -155,20 +170,25 @@ def _make_wfs_table(
     wf_flat = pd.DataFrame(
         {
             "index": np.arange(wf_idx.shape[0]),
-            "sample": spike_samples[wf_idx].astype(int),
+            "sample": spike_samples[wf_idx].astype(np.int64),
             "cluster": spike_clusters[wf_idx].astype(int),
             "peak_channel": spike_channels[wf_idx].astype(int),
+            "waveform_index": np.zeros(wf_idx.shape[0], int),
         }
     )
+    # we pre-compute the final absolute index of each waveform in the flat output array
+    unique_clusters, cluster_index, cluster_counts = np.unique(
+        wf_flat["cluster"], return_inverse=True, return_counts=True)
+    index_order_clusters = np.argsort(cluster_index, kind='stable')
+    wf_flat.loc[index_order_clusters, 'waveform_index'] = np.arange(wf_flat.shape[0])  # 3d "flat" version

     return wf_flat, unit_ids
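The stable argsort above is what lets the extractor write waveforms in chronological chunks while the output file ends up ordered by cluster. A small self-contained illustration of the index computation (cluster ids are hypothetical):

```python
import numpy as np

clusters = np.array([7, 3, 7, 3, 9])  # chronological order of spikes
_, cluster_index = np.unique(clusters, return_inverse=True)
order = np.argsort(cluster_index, kind="stable")

waveform_index = np.zeros(clusters.size, int)
waveform_index[order] = np.arange(clusters.size)
# waveform_index -> [2 0 3 1 4]: cluster 3 lands in slots 0-1, cluster 7 in 2-3,
# cluster 9 in 4, with chronological order preserved within each cluster
# thanks to the stable sort
```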
geom_dict["sample_shift"], axis=0) + snip = fshift(snip, geom_dict["sample_shift"], axis=-1) if "bad_channel_interpolation" in preprocess_steps: snip = interpolate_bad_channels( @@ -225,7 +243,7 @@ def write_wfs_chunk( k_kwargs = { "ntr_pad": 60, "ntr_tap": 0, - "lagc": int(my_sr.fs / 10), + "lagc": 0, # no agc for the median estimator of common reference channel "butter_kwargs": {"N": 3, "Wn": 0.01, "btype": "highpass"}, } if "car" in preprocess_steps: @@ -235,12 +253,8 @@ def write_wfs_chunk( if "kfilt" in preprocess_steps: kfilt_func = lambda dat: kfilt(dat, **k_kwargs) # noqa: E731 snip = kfilt_func(snip) - - wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array( - snip, df, channel_neighbors, add_nan_trace=True - )[0] - - wfs_mmap.flush() + iw = wf_flat['waveform_index'].values + wfs_mmap[iw, :, :] = extract_wfs_array(snip, df, channel_neighbors, add_nan_trace=True)[0] def extract_wfs_cbin( @@ -255,11 +269,12 @@ def extract_wfs_cbin( trough_offset=42, spike_length_samples=128, chunksize_samples=int(3000), - reader_kwargs={}, + reader_kwargs=None, n_jobs=None, wfs_dtype=np.float32, - preprocess_steps=[], - seed=None + preprocess_steps=None, + seed=None, + scratch_dir=None, ): """ Given a bin file and locations of spikes, extract waveforms for each unit, compute @@ -269,7 +284,7 @@ def extract_wfs_cbin( reference procedure is applied in the spatial dimension. The following files will be generated: - - waveforms.traces.npy: `(num_units, max_wf, nc, spike_length_samples)` + - waveforms.traces.npy: `(total_waveforms, nc, spike_length_samples)` This file contains the lightly processed waveforms indexed by cluster in the first dimension. By default `max_wf=256, nc=40, spike_length_samples=128`. @@ -307,8 +322,11 @@ def extract_wfs_cbin( :param wfs_dtype: Data type of raw waveforms saved (default np.float32) :param preprocess: Preprocessing options to apply, list which must be a subset of ["phase_shift", "bad_channel_interpolation", "butterworth", "car", "kfilt"] + By default a butterworth 300Hz high-pass and the rephasing of the channels is perfomed """ n_jobs = n_jobs or int(cpu_count() / 2) + preprocess_steps = ['butterworth', 'phase_shift'] if preprocess_steps is None else preprocess_steps + reader_kwargs = {} if reader_kwargs is None else reader_kwargs assert set(preprocess_steps).issubset( { @@ -327,6 +345,13 @@ def extract_wfs_cbin( if h is None: h = sr.geometry + if sr.is_mtscomp: + bin_file = sr.decompress_to_scratch(scratch_dir=scratch_dir) + sr = spikeglx.Reader(bin_file, **reader_kwargs) + file_to_unlink = bin_file + else: + file_to_unlink = None + s0_arr = np.arange(0, sr.ns, chunksize_samples) s1_arr = s0_arr + chunksize_samples s1_arr[-1] = sr.ns @@ -353,7 +378,7 @@ def extract_wfs_cbin( elif channel_labels is None: channel_labels = np.zeros(sr.nc - sr.nsync) - nwf = len(wf_flat) + nwf = wf_flat.shape[0] nu = unit_ids.shape[0] logger.info(f"Extracting {nwf} waveforms from {nu} units") @@ -365,9 +390,9 @@ def extract_wfs_cbin( # this intermediate memmap is written to in parallel # the waveforms are ordered only by their chronological position # in the recording, as we are reading them in time chunks - int_fn = output_dir.joinpath("_wf_extract_intermediate.npy") + traces_fn = output_dir.joinpath("waveforms.traces.npy") wfs = open_memmap( - int_fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32 + traces_fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32 ) slices = [ @@ -379,8 +404,7 @@ def extract_wfs_cbin( delayed(write_wfs_chunk)( i, 
             bin_file,
-            int_fn,
-            wfs.shape,
+            wfs,
             h,
             channel_labels,
             channel_neighbors,
@@ -396,86 +420,46 @@ def extract_wfs_cbin(
     )

     # output files
-    traces_fn = output_dir.joinpath("waveforms.traces.npy")
     templates_fn = output_dir.joinpath("waveforms.templates.npy")
     table_fn = output_dir.joinpath("waveforms.table.pqt")
     channels_fn = output_dir.joinpath("waveforms.channels.npz")

-    ## rearrange and save traces by unit
+    # rearrange dataframe: sort waveforms by cluster and aggregate by cluster
+    wf_flat.sort_values(by=["cluster", "sample"], inplace=True)
+    df_clusters = aggregate_by_clusters(wf_flat)
+
+    # we want to store the index of the waveform within each cluster to facilitate loading later
+    wf_flat['index_within_clusters'] = np.ones(wf_flat.shape[0])
+    inewc = np.diff(wf_flat['cluster'].values, prepend=wf_flat['cluster'].values[0]) != 0
+    wf_flat.loc[inewc, 'index_within_clusters'] = -df_clusters['count'].values[:-1] + 1
+    wf_flat['index_within_clusters'] = np.cumsum(wf_flat['index_within_clusters'].values).astype(int) - 1
+
     # store medians across waveforms
     wfs_templates = np.full((nu, nc, spike_length_samples), np.nan, dtype=np.float32)
-    # create waveform output file (~2-3 GB)
-    traces_by_unit = open_memmap(
-        traces_fn,
-        mode="w+",
-        shape=(nu, max_wf, nc, spike_length_samples),
-        dtype=wfs_dtype,
-    )
     logger.info("Writing to output files")
-
-    for i, u in enumerate(unit_ids):
-        idx = np.where(wf_flat["cluster"] == u)[0]
-        nwf_u = idx.shape[0]
-        # reopening these memmaps on each iteration
-        # forces Python to clean up each large array it loads
-        # and prevent a memory leak
-        wfs = open_memmap(
-            int_fn, mode="r+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
-        )
-        traces_by_unit = open_memmap(
-            traces_fn,
-            mode="r+",
-            shape=(nu, max_wf, nc, spike_length_samples),
-            dtype=wfs_dtype,
-        )
-        # write up to 256 waveforms and leave the rest of dimensions 1-3 as NaNs
-        traces_by_unit[i, : min(max_wf, nwf_u), :, :] = wfs[idx].astype(wfs_dtype)
-        traces_by_unit.flush()
-        # populate this array in memory as it's 256x smaller
-        wfs_templates[i, :, :] = np.nanmedian(wfs[idx], axis=0)
-
-    # cleanup intermediate file
-    int_fn.unlink()
-
+    wfs = open_memmap(traces_fn)
+    for i, rec in enumerate(df_clusters.itertuples()):
+        wfs_templates[i] = np.nanmedian(wfs[rec.first_index:rec.last_index + 1], axis=0)
     # save templates
     np.save(templates_fn, wfs_templates)

     # save the waveform table
-    # add in dummy rows and order by unit, and then sample
-    unit_counts = wf_flat.groupby("cluster")["sample"].count().reset_index(name="count")
-    unit_counts["missing"] = max_wf - unit_counts["count"]
-    missing_wf = unit_counts[unit_counts["missing"] > 0]
-    total_missing = sum(missing_wf.missing)
-    extra_rows = pd.DataFrame(
-        {
-            "sample": [np.nan] * total_missing,
-            "peak_channel": [np.nan] * total_missing,
-            "index": [np.nan] * total_missing,
-            "cluster": sum(
-                [[row["cluster"]] * row["missing"] for _, row in missing_wf.iterrows()],
-                [],
-            ),
-        }
-    )
-    save_df = pd.concat([wf_flat, extra_rows])
-    # now the waveforms are arranged by cluster, and then in time
-    # these match dimensions 0 and 1 of waveforms.traces.npy
-    save_df.sort_values(["cluster", "sample"], inplace=True)
-    save_df.to_parquet(table_fn)
+    wf_flat.to_parquet(table_fn)

     # save channel map for each waveform
     # these values are now reordered so that they match the pqt
     # and the traces file
-    peak_channel = np.nan_to_num(save_df["peak_channel"].to_numpy(), nan=-1).astype(
-        np.int16
-    )
-    dummy_idx = np.where(peak_channel >= 0)[0]
-    # leave "missing" waveforms as -1 since we can't have NaN with int dtype
-    chan_map = np.ones((max_wf * nu, nc), np.int16) * -1
-    chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)]
+    peak_channel = np.nan_to_num(wf_flat["peak_channel"].to_numpy(), nan=-1).astype(np.int16)
+    chan_map = channel_neighbors[peak_channel.astype(int)]
     np.savez(channels_fn, channels=chan_map)

+    # clean up the cached bin file
+    if file_to_unlink is not None:
+        file_to_unlink.with_suffix(".meta").unlink()
+        file_to_unlink.unlink()
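Because the traces file is now flat and cluster-sorted, per-cluster access reduces to one contiguous slice. A sketch of reading one cluster's waveforms directly from the outputs (file names as written above; the output directory is hypothetical):

```python
import numpy as np
import pandas as pd
from ibldsp.waveform_extraction import aggregate_by_clusters

out_dir = "path/to/waveforms_output"  # hypothetical: wherever extract_wfs_cbin wrote its files
df_wavs = pd.read_parquet(f"{out_dir}/waveforms.table.pqt")
traces = np.lib.format.open_memmap(f"{out_dir}/waveforms.traces.npy")  # (total_waveforms, nc, ns)

df_clusters = aggregate_by_clusters(df_wavs)
rec = df_clusters.iloc[0]  # first cluster
wfs_cluster = traces[int(rec.first_index):int(rec.last_index) + 1]  # contiguous read, no fancy indexing
template = np.nanmedian(wfs_cluster, axis=0)  # same reduction the extractor uses for waveforms.templates.npy
```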
"missing" waveforms as -1 since we can't have NaN with int dtype - chan_map = np.ones((max_wf * nu, nc), np.int16) * -1 - chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)] + peak_channel = np.nan_to_num(wf_flat["peak_channel"].to_numpy(), nan=-1).astype(np.int16) + chan_map = channel_neighbors[peak_channel.astype(int)] np.savez(channels_fn, channels=chan_map) + # clean up the cached bin file + if file_to_unlink is not None: + file_to_unlink.with_suffix(".meta").unlink() + file_to_unlink.unlink() class WaveformsLoader: + data_version = None """ Interface to the output of `extract_wfs_cbin`. Requires the following four files to @@ -502,26 +486,17 @@ class WaveformsLoader: WaveformsLoader.load_waveforms() and random_waveforms() allow selection of a subset of waveforms. - """ def __init__( self, data_dir, - max_wf=256, trough_offset=42, - spike_length_samples=128, - num_channels=40, - wfs_dtype=np.float32 + **kwargs, ): self.data_dir = Path(data_dir) - self.max_wf = max_wf self.trough_offset = trough_offset - self.spike_length_samples = spike_length_samples - self.num_channels = num_channels - self.wfs_dtype = wfs_dtype - self.traces_fp = self.data_dir.joinpath("waveforms.traces.npy") self.templates_fp = self.data_dir.joinpath("waveforms.templates.npy") self.table_fp = self.data_dir.joinpath("waveforms.table.pqt") @@ -532,89 +507,92 @@ def __init__( assert self.table_fp.exists(), "waveforms.table.pqt file missing!" assert self.channels_fp.exists(), "waveforms.channels.npz file missing!" - # ingest parquet table - self.table = pd.read_parquet(self.table_fp).reset_index(drop=True).drop(columns=["index"]) - self.table["sample"] = self.table["sample"].astype("Int64") - self.table["peak_channel"] = self.table["peak_channel"].astype("Int64") - self.num_labels = self.table["cluster"].nunique() - self.labels = np.array(self.table["cluster"].unique()) - self.total_wfs = sum(~self.table["peak_channel"].isna()) - self.table["wf_number"] = np.tile(np.arange(self.max_wf), self.num_labels) - self.table["linear_index"] = np.arange(len(self.table)) - - traces_shape = (self.num_labels, max_wf, num_channels, spike_length_samples) - templates_shape = (self.num_labels, num_channels, spike_length_samples) + self.traces = np.lib.format.open_memmap(self.traces_fp) + self.df_wav = pd.read_parquet(self.table_fp).reset_index(drop=True).drop(columns=["index"]) + if len(self.traces.shape) == 4: + self.data_version = 1 + self.df_wav["sample"] = self.df_wav["sample"].astype('Int64') + self.df_wav["peak_channel"] = self.df_wav["peak_channel"].astype('Int64') + self.df_wav['waveform_index'] = np.arange(self.df_wav.shape[0], dtype=np.int64) + self.df_wav['index_within_cluster'] = np.tile(np.arange(self.traces.shape[1]), self.traces.shape[0]) + self.total_wfs = sum(~self.df_wav["peak_channel"].isna()) + else: + self.data_version = 2 - self.traces = np.lib.format.open_memmap(self.traces_fp, dtype=wfs_dtype, shape=traces_shape) - self.templates = np.lib.format.open_memmap(self.templates_fp, dtype=np.float32, shape=templates_shape) - self.channels = np.load(self.channels_fp, allow_pickle="True")["channels"] + self.df_clusters = aggregate_by_clusters(self.df_wav) + self.templates = np.lib.format.open_memmap(self.templates_fp, dtype=np.float32) + self.channels = np.load(self.channels_fp)["channels"] def __repr__(self): - s1 = f"WaveformsLoader with {self.total_wfs} waveforms in {self.wfs_dtype} from {self.num_labels} labels.\n" - s2 = f"Data path: {self.data_dir}\n" - s3 = f"{self.spike_length_samples} 
     def __repr__(self):
-        s1 = f"WaveformsLoader with {self.total_wfs} waveforms in {self.wfs_dtype} from {self.num_labels} labels.\n"
-        s2 = f"Data path: {self.data_dir}\n"
-        s3 = f"{self.spike_length_samples} samples, {self.num_channels} channels, {self.max_wf} max waveforms per label\n"
-
-        return s1 + s2 + s3
+        return f"""
+        WaveformsLoader data version {self.data_version}
+        {self.nw:_} total waveforms, {self.ns} samples, {self.nc} channels
+        {self.nu:_} units, {self.max_wf:_} max waveforms per label
+        dtype: {self.wfs_dtype}
+        data path: {self.data_dir}
+        """

+    @property
+    def max_wf(self):
+        return self.df_clusters['count'].max()

     @property
-    def wf_counts(self):
-        """
-        pandas Series containing number of (non-NaN) waveforms for each label.
-        """
-        return self.table.groupby("cluster").count()["sample"].rename("num_wfs")
+    def wfs_dtype(self):
+        return self.traces.dtype
+
+    @property
+    def nu(self):
+        return self.df_clusters.shape[0]
+
+    @property
+    def ns(self):
+        return self.traces.shape[-1]
+
+    @property
+    def nc(self):
+        return self.traces.shape[-2]
+
+    @property
+    def nw(self):
+        return self.df_wav.shape[0]

     def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=False):
         """
         Returns a specified subset of waveforms from the dataset.

         :param labels: (list, NumPy array) Label ids (usually clusters) from which to get waveforms.
-        :param indices: (list, NumPy array) Waveform indices to grab for each waveform.
-            Can be 1D in which case the same indices are returned for each waveform, or
-            2D with first dimension == len(labels) to return a specific set of indices for
-            each waveform.
+        :param indices: (list, NumPy array) Within-cluster waveform indices to grab for each label.
+            1D, or 2D with first dimension == len(labels) for data version 1.
         :param return_info: If True, returns waveforms, table, channels, where table is a DF
             containing information about the waveforms returned, and channels is the channel map
             for each waveform.
         :param flatten: If True, returns all waveforms stacked along dimension zero, otherwise
             returns array of shape (num_labels, num_indices_per_label, num_channels, spike_length_samples)
         """
-        if labels is None:
-            labels = self.labels
-        if indices is None:
-            indices = np.arange(self.max_wf)
-
-        labels = np.array(labels)
-        label_idx = np.array([np.where(self.labels == label)[0][0] for label in labels])
-        indices = np.array(indices)
-
-        num_labels = labels.shape[0]
-
-        if indices.ndim == 1:
-            indices = np.tile(indices, (num_labels, 1))
-
-        wfs = self.traces[label_idx[:, None], indices].astype(np.float32)
-
-        if flatten:
-            wfs = wfs.reshape(-1, self.num_channels, self.spike_length_samples)
-
-        info = self.table[self.table["cluster"].isin(labels)].copy()
-        dfs = []
-        for i, l in enumerate(labels):
-            _idx = indices[i]
-            dfs.append(info[(info["wf_number"].isin(_idx)) & (info["cluster"] == l)])
-        info = pd.concat(dfs).reset_index(drop=True)
-
-        channels = self.channels[info["linear_index"].to_numpy()].astype(int)
-
+        labels = np.array(self.df_clusters.index if labels is None else labels)
+        iw, _ = ismember(self.df_wav['cluster'], labels)
+        if self.data_version == 1:
+            indices = np.array(np.arange(self.max_wf) if indices is None else indices)
+            indices = np.tile(indices, (labels.size, 1)) if indices.ndim < 2 else indices
+            assert indices.shape[0] == labels.size, \
+                "If indices is a 2D-array, the first dimension must match the number of labels."
+            _, iu, _ = np.intersect1d(self.df_clusters.index, labels, return_indices=True)
+            assert iu.size == labels.size, "Not all labels found in dataset."
+            wfs = self.traces[iu[:, np.newaxis], indices].astype(np.float32)
+            if flatten:
+                wfs = wfs.reshape(-1, self.nc, self.ns)
+        elif self.data_version == 2:
+            if indices is not None:
+                iw = np.where(iw)[0]
+                iw = iw[self.df_wav.loc[iw, 'index_within_clusters'].isin(np.atleast_1d(np.array(indices)))]
+            wfs = self.traces[iw].astype(np.float32)
+        info = self.df_wav.loc[iw, :].copy()
+        channels = self.channels[iw].astype(int)
         n_nan = sum(info["sample"].isna())
         if n_nan > 0:
-            logger.warning(f"{n_nan} NaN waveforms included in result.")
+            logger.info(f"{n_nan} NaN waveforms included in result.")

         if return_info:
             return wfs, info, channels
-
-        logger.info("Use return_info=True and check the table for details.")
-
-        return wfs
+        else:
+            return wfs

     def random_waveforms(
         self,
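A short usage sketch of the refactored loader (the output directory and label ids are hypothetical; on version-2 data the call returns the flat `(n, nc, ns)` stack together with the matching table rows and channel map):

```python
import numpy as np
from ibldsp.waveform_extraction import WaveformsLoader

wfl = WaveformsLoader("path/to/waveforms_output")  # hypothetical output of extract_wfs_cbin
print(wfl)  # reports data version, waveform count, samples, channels

# all waveforms for two clusters, with their table rows and channel maps
wfs, info, channels = wfl.load_waveforms(labels=[1, 2])

# only the first ten waveforms within each of those clusters
wfs, info, channels = wfl.load_waveforms(labels=[1, 2], indices=np.arange(10))
```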
 vect = np.arange(vals.size).astype(np.float32)
diff --git a/src/ibldsp/waveforms.py b/src/ibldsp/waveforms.py
index b05882d..edae543 100644
--- a/src/ibldsp/waveforms.py
+++ b/src/ibldsp/waveforms.py
@@ -6,7 +6,9 @@
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
+import matplotlib as mpl
 import scipy
+
 from ibldsp.utils import parabolic_max
 from ibldsp.fourier import fshift

@@ -231,16 +233,24 @@ def find_tip_trough(arr_peak, arr_peak_real, df):
     return df, arr_peak


-def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=1.5, **axkwargs):
+def plot_wiggle(wav, fs=1, ax=None, scale=0.3, clip=10, fill_sign=-1, plot_kwargs=None, fill_kwargs=None):
     """
     Displays a multi-trace waveform as wiggle traces, with negative and/or positive peaks filled.
     :param wav: (nchannels, nsamples)
-    :param axkwargs: keyword arguments to feed to ax.set()
-    :return:
+    :param fs: sampling rate
+    :param ax: axis to plot on
+    :param scale: waveform amplitude displayed as one inter-trace distance: with scale=20e-6, one inter-trace is 20 uV
+    :param clip: maximum absolute amplitude for the traces, in inter-trace units
+    :param fill_sign: -1 to fill negative peaks (default for spikes), 1 for positive
+    :param plot_kwargs: kwargs for the line plot
+    :param fill_kwargs: kwargs for the fill
+    :return: axis
     """
     if ax is None:
         fig, ax = plt.subplots()
+    plot_kwargs = {'color': 'k', 'linewidth': 0.5} | (plot_kwargs or {})
+    fill_kwargs = {'color': 'k', 'aa': True} | (fill_kwargs or {})
     nc, ns = wav.shape
     vals = np.c_[wav, wav[:, :1] * np.nan].ravel()  # flat view of the 2d array
     vect = np.arange(vals.size).astype(np.float32)
@@ -255,22 +265,51 @@ def plot_wiggle(
     m = (y2 - y1) / (x2 - x1)
     c = y1 - m * x1
     # tack these values onto the end of the existing data
-    x = np.hstack([vals, np.zeros_like(c)]) * scalar
+    x = np.hstack([vals, np.zeros_like(c)]) / scale
     x = np.maximum(np.minimum(x, clip), -clip)
     y = np.hstack([vect, c])
     # resort the data
     order = np.argsort(y)
     # shift from amplitudes to plotting coordinates
     x_shift, y = y[order].__divmod__(ns + 1)
-    ax.plot(y / fs, x[order] + x_shift + 1, 'k', linewidth=.5)
-    x[x > 0] = np.nan
+    ax.plot(y / fs, x[order] + x_shift + 1, **plot_kwargs)
+    if fill_sign < 0:
+        x[x > 0] = np.nan
+    else:
+        x[x < 0] = np.nan
     x = x[order] + x_shift + 1
-    ax.fill(y / fs, x, 'k', aa=True)
-    ax.set(xlim=[0, ns / fs], ylim=[0, nc], xlabel='sample', ylabel='trace')
+    ax.fill(y / fs, x, **fill_kwargs)
+    ax.set(xlim=[0, ns / fs], ylim=[0, nc])
     plt.tight_layout()
     return ax


+def double_wiggle(wav, fs=1, ax=None, colors=None, **kwargs):
+    """
+    Double trouble: this wiggle colours both the negative and the positive peaks.
+    :param wav: (nchannels, nsamples)
+    :param fs: sampling rate
+    :param ax: axis to plot on
+    :param colors: pair of colours for the negative and positive fills (defaults drawn from the PuOr colormap)
+    :param kwargs: additional arguments passed through to plot_wiggle (scale, clip, ...)
+    :return: axis
+    """
+    if colors is None:
+        cmap = 'PuOr'
+        _cmap = mpl.colormaps.get_cmap(cmap)
+        colors = _cmap(np.linspace(0, 1, 256))
+        colors = [colors[50], colors[-50]]
+    if ax is None:
+        fig, ax = plt.subplots()
+    plot_wiggle(wav, fs=fs / 1e3, ax=ax, plot_kwargs={'linewidth': 0}, fill_kwargs={'color': colors[0]}, **kwargs)
+    plot_wiggle(wav, fs=fs / 1e3, ax=ax, fill_sign=1, plot_kwargs={'linewidth': 0.5}, fill_kwargs={'color': colors[1]}, **kwargs)
+    return ax
+
+
 def plot_peaktiptrough(df, arr, ax, nth_wav=0, plot_grey=True, fs=30000):
     # Time axis
     nech, ntr = arr[nth_wav].shape
diff --git a/src/neurodsp/__init__.py b/src/neurodsp/__init__.py
deleted file mode 100644
index 99f527a..0000000
--- a/src/neurodsp/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import ibldsp, sys
-from warnings import warn
-
-sys.modules["neurodsp"] = ibldsp
-warn(
-    "neurodsp has been renamed to ibldsp and the old name will be deprecated on 01-Oct-2024.",
-    FutureWarning,
-)
diff --git a/src/spikeglx.py b/src/spikeglx.py
index f6a3c55..b7cfa3c 100644
--- a/src/spikeglx.py
+++ b/src/spikeglx.py
@@ -2,6 +2,8 @@
 import logging
 from pathlib import Path
 import re
+import shutil
+import time

 import numpy as np

@@ -14,7 +16,7 @@
 SAMPLE_SIZE = 2  # int16
 DEFAULT_BATCH_SIZE = 1e6
-_logger = logging.getLogger("ibllib")
+_logger = logging.getLogger(__name__)


 def _get_companion_file(sglx_file, pattern='.meta'):
@@ -374,6 +376,27 @@ def compress_file(self, keep_original=True, **kwargs):
         self.file_bin = file_out
         return file_out

+    def decompress_to_scratch(self, scratch_dir=None):
+        """
+        Decompresses the file to a temporary directory and copies over the metadata file.
+        """
+        if scratch_dir is None:
+            bin_file = Path(self.file_bin).with_suffix('.bin')
+        else:
+            scratch_dir.mkdir(exist_ok=True, parents=True)
+            bin_file = Path(scratch_dir).joinpath(self.file_bin.name).with_suffix('.bin')
+        shutil.copy(self.file_meta_data, bin_file.with_suffix('.meta'))
+        if not bin_file.exists():
+            t0 = time.time()
+            _logger.info('File is compressed, decompressing to a temporary file...')
+            self.decompress_file(
+                keep_original=True, out=bin_file.with_suffix('.bin_temp'), check_after_decompress=False, overwrite=True
+            )
+            shutil.move(bin_file.with_suffix('.bin_temp'), bin_file)
+            _logger.info(f"Decompression complete: {time.time() - t0:.2f}s")
+        return bin_file
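A sketch of how the extractor uses this helper (the file name is hypothetical; `Reader` and `is_mtscomp` as used elsewhere in this diff):

```python
import spikeglx

sr = spikeglx.Reader("probe00.ap.cbin")  # hypothetical compressed recording
if sr.is_mtscomp:
    # writes probe00.ap.bin (and .meta) next to the source, or under scratch_dir if given
    bin_file = sr.decompress_to_scratch()
    sr = spikeglx.Reader(bin_file)
```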

     def decompress_file(self, keep_original=True, **kwargs):
         """
         Decompresses a mtscomp file
diff --git a/src/tests/unit/cpu/test_ibldsp.py b/src/tests/unit/cpu/test_ibldsp.py
index a8214c9..16e4391 100644
--- a/src/tests/unit/cpu/test_ibldsp.py
+++ b/src/tests/unit/cpu/test_ibldsp.py
@@ -356,6 +356,22 @@ def test_firstlast_slices(self):
             my_rms_[wg.iw] = utils.rms(my_sig[sl])
         self.assertTrue(np.all(my_rms_ == my_rms))

+    def test_firstlast_splicing(self):
+        sig_in = np.random.randn(600)
+        sig_out = np.zeros_like(sig_in)
+        wg = utils.WindowGenerator(ns=600, nswin=100, overlap=20)
+        for first, last, amp in wg.firstlast_splicing:
+            sig_out[first:last] = sig_out[first:last] + amp * sig_in[first:last]
+        np.testing.assert_allclose(sig_out, sig_in)
+
+    def test_firstlast_valid(self):
+        sig_in = np.random.randn(600)
+        sig_out = np.zeros_like(sig_in)
+        wg = utils.WindowGenerator(ns=600, nswin=100, overlap=20)
+        for first, last, first_valid, last_valid in wg.firstlast_valid:
+            sig_out[first_valid:last_valid] = sig_in[first_valid:last_valid]
+        np.testing.assert_array_equal(sig_out, sig_in)
+
     def test_tscale(self):
         wg = utils.WindowGenerator(ns=500, nswin=100, overlap=50)
         ts = wg.tscale(fs=1000)
@@ -626,25 +642,3 @@ def test_compute_features(self):
         self.assertEqual(multi_index, list(df.index))
         self.assertEqual(["snippet_id", "channel_id"], list(df.index.names))
         self.assertEqual(num_snippets * (self.nc - 1), len(df))
-
-
-class TestNameDeprecationDate(unittest.TestCase):
-    def test_neurodsp_import(self):
-        # Check that the old import still works and gives the same package.
-        # (ibldsp.voltage is imported at the top of this file.)
-        with self.assertWarnsRegex(FutureWarning, "01-Oct-2024"):
-            import neurodsp
-        self.assertEqual(neurodsp.voltage, voltage)
-
-    def test_deprecation_countdown(self):
-        # Fail on 01-Oct-2024, when `neurodsp` will be retired.
-        # When this test fails, remove the entire dummy
-        # `neurodsp` package at the top level of the ibl-neuropixel
-        # repository
-        import datetime
-
-        if datetime.datetime.now() > datetime.datetime(2024, 10, 1):
-            raise NotImplementedError(
-                "neurodsp will not longer be supported. "
-                "Change all references to ibldsp."
-            )
diff --git a/src/tests/unit/cpu/test_waveforms.py b/src/tests/unit/cpu/test_waveforms.py
index 4638788..d77dba6 100644
--- a/src/tests/unit/cpu/test_waveforms.py
+++ b/src/tests/unit/cpu/test_waveforms.py
@@ -1,9 +1,12 @@
 from pathlib import Path
+import shutil
+import tempfile
+import unittest

 import numpy as np
 import pandas as pd
-import tempfile
-import shutil
+import matplotlib.pyplot as plt
+import scipy

 import ibldsp.utils as utils
 import ibldsp.waveforms as waveforms
@@ -11,9 +14,7 @@
 from neurowaveforms.model import generate_waveform
 from neuropixel import trace_header
 from ibldsp.fourier import fshift
-import scipy
-import unittest

 TEST_PATH = Path(__file__).parent.joinpath("fixtures")

@@ -188,6 +189,7 @@ class TestWaveformExtractorArray(unittest.TestCase):
     channel_neighbors = utils.make_channel_index(geom, radius=200.0)
     # radius = 200um, 38 chans
     num_channels = 40
+    arr = arr.T

     def test_extract_waveforms_array(self):
         wfs, _, _ = waveform_extraction.extract_wfs_array(
@@ -307,7 +309,7 @@ class TestWaveformExtractorBin(unittest.TestCase):
     max_wf = 25

     # 2 clusters
-    spike_samples = np.repeat(np.arange(0, ns, 1600), 2)  # 50 spikes
+    spike_samples = np.repeat(np.arange(0, ns, 1600), 2)  # 50 spikes, the first 2 of which fall on sample 0
     spike_channels = np.tile(np.array([100, 368]), 25)
     spike_clusters = np.tile(np.array([1, 2]), 25)

@@ -327,19 +329,20 @@ def tearDown(self):
         shutil.rmtree(self.tmpdir)

     def _ground_truth_values(self):
-
+        # 48 and 24 are hard-coded here because the first 2 spikes fall on sample 0 and are rejected
         nc_extract = self.chan_map.shape[1]
         gt_templates = np.ones((self.n_clusters, nc_extract, self.ns_extract), np.float32) * np.nan
-        gt_waveforms = np.ones((self.n_clusters, self.max_wf, nc_extract, self.ns_extract), np.float32) * np.nan
+        gt_waveforms = np.ones((48, nc_extract, self.ns_extract), np.float32) * np.nan

         c0_chans = self.chan_map[100].astype(np.float32)
         gt_templates[0, :, :] = np.tile(c0_chans, (self.ns_extract, 1)).T
-        gt_waveforms[0, :self.max_wf - 1, :, :] = np.tile(c0_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)
+
+        gt_waveforms[:24, :, :] = gt_templates[0]

         c1_chans = self.chan_map[368].astype(np.float32)
         c1_chans[c1_chans == 384] = np.nan
         gt_templates[1, :, :] = np.tile(c1_chans, (self.ns_extract, 1)).T
-        gt_waveforms[1, :self.max_wf - 1, :, :] = np.tile(c1_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)
+        gt_waveforms[24:, :, :] = gt_templates[1]

         return gt_templates, gt_waveforms

@@ -357,16 +360,20 @@ def test_extract_waveforms_bin(self):
         )
         templates = np.load(self.tmpdir.joinpath("waveforms.templates.npy"))
         waveforms = np.load(self.tmpdir.joinpath("waveforms.traces.npy"))
+        table = pd.read_parquet(self.tmpdir.joinpath("waveforms.table.pqt"))

-        for u in [0, 1]:
-            assert np.allclose(np.nan_to_num(templates[u]), np.nanmedian(waveforms[u], axis=0))
+        cluster_ids = table.cluster.unique()
+
+        for i, u in enumerate(cluster_ids):
+            inds = table[table.cluster == u].waveform_index.to_numpy()
+            assert np.allclose(templates[i], np.nanmedian(waveforms[inds], axis=0), equal_nan=True)

         gt_templates, gt_waveforms = self._ground_truth_values()

         assert np.allclose(np.nan_to_num(gt_templates), np.nan_to_num(templates))
         assert np.allclose(np.nan_to_num(gt_waveforms), np.nan_to_num(waveforms))

-        wfl = waveform_extraction.WaveformsLoader(self.tmpdir, max_wf=self.max_wf)
+        wfl = waveform_extraction.WaveformsLoader(self.tmpdir)

         wfs = wfl.load_waveforms(return_info=False)
         assert np.allclose(np.nan_to_num(waveforms), np.nan_to_num(wfs))
@@ -374,16 +381,20 @@ def test_extract_waveforms_bin(self):
         labels = np.array([1, 2])
         indices = np.arange(10)

+        # test the waveform loader
         wfs, info, channels = wfl.load_waveforms(labels=labels, indices=indices)
+
         # right waveforms
-        assert np.allclose(np.nan_to_num(waveforms[:, :10]), np.nan_to_num(wfs))
+        assert np.allclose(np.nan_to_num(waveforms[:10, :]), np.nan_to_num(wfs[info['cluster'] == 1, :, :]))
+        assert np.allclose(np.nan_to_num(waveforms[25:35, :]), np.nan_to_num(wfs[info['cluster'] == 2, :, :]))
         # right channels
         assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()])

-        wfs, info, channels = wfl.load_waveforms(labels=labels, indices=np.array([[1, 2, 3], [5, 6, 7]]))
-        # right waveforms
-        assert np.allclose(np.nan_to_num(waveforms[0, [1, 2, 3]]), np.nan_to_num(wfs[0]))
-        assert np.allclose(np.nan_to_num(waveforms[1, [5, 6, 7]]), np.nan_to_num(wfs[1]))
-        # right channels
-        assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()])
+
+def test_wiggle():
+    wav = generate_waveform()
+    wav = wav / np.max(np.abs(wav)) * 120 * 1e-6
+    fig, ax = plt.subplots(1, 2)
+    waveforms.plot_wiggle(wav, scale=40 * 1e-6, ax=ax[0])
+    waveforms.double_wiggle(wav, scale=40 * 1e-6, fs=30_000, ax=ax[1])
+    plt.close('all')
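Taken together, the new pipeline can be exercised end to end roughly as follows. This is a sketch only: the paths are hypothetical, the spike arrays are synthetic stand-ins for a spike sorter's output, and the positional order of the first five `extract_wfs_cbin` arguments is assumed from the function body above.

```python
import numpy as np
from ibldsp.waveform_extraction import extract_wfs_cbin, WaveformsLoader
from ibldsp.waveforms import double_wiggle

# hypothetical inputs: a compressed recording and a sorter's output
cbin_file = "probe00.ap.cbin"
spike_samples = np.arange(1000, 30_000 * 60, 3000)   # spike times in samples
spike_clusters = np.zeros(spike_samples.size, int)   # all spikes assigned to cluster 0
spike_channels = np.full(spike_samples.size, 120)    # peak channel of each spike

# compressed input is decompressed to scratch automatically;
# defaults apply the 300 Hz high-pass Butterworth and the phase shift correction
extract_wfs_cbin(cbin_file, "waveforms_output", spike_samples, spike_clusters, spike_channels)

wfl = WaveformsLoader("waveforms_output")
wfs, info, channels = wfl.load_waveforms(labels=[0], indices=np.arange(10))
# plot the median of the loaded waveforms, colouring negative and positive lobes
double_wiggle(np.nanmedian(wfs, axis=0), fs=30_000, scale=20e-6)
```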