diff --git a/release_notes.md b/release_notes.md
index 07f5073..eddb2c1 100644
--- a/release_notes.md
+++ b/release_notes.md
@@ -1,4 +1,6 @@
 # 0.10.0
+## 0.10.3 2024-04-18
+- Patch fixing memory leaks in the `waveform_extraction` module.
 ## 0.10.2 2024-04-10
 - Add `waveform_extraction` module to `ibldsp`. This includes the `extract_wfs_array` and `extract_wfs_cbin` methods.
 - Add code for performing subsample shifts of waveforms.
diff --git a/setup.py b/setup.py
index edb0434..4668246 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 setuptools.setup(
     name="ibl-neuropixel",
-    version="0.10.2",
+    version="0.10.3",
     author="The International Brain Laboratory",
     description="Collection of tools for Neuropixel 1.0 and 2.0 probes data",
     long_description=long_description,
diff --git a/src/ibldsp/waveform_extraction.py b/src/ibldsp/waveform_extraction.py
index 2c0e66d..c026dd9 100644
--- a/src/ibldsp/waveform_extraction.py
+++ b/src/ibldsp/waveform_extraction.py
@@ -1,16 +1,19 @@
+import logging
+
 import scipy
 import pandas as pd
 import numpy as np
 from numpy.lib.format import open_memmap
-import neuropixel
-import spikeglx
-
 from joblib import Parallel, delayed, cpu_count
+import neuropixel
+import spikeglx
 from ibldsp.voltage import detect_bad_channels, interpolate_bad_channels, car
 from ibldsp.fourier import fshift
 from ibldsp.utils import make_channel_index

+logger = logging.getLogger(__name__)
+

 def extract_wfs_array(
     arr,
@@ -83,7 +86,12 @@ def _get_channel_labels(sr, num_snippets=20, verbose=True):
     if verbose:
         from tqdm import trange

-    start = (np.linspace(100, int(sr.rl) - 100, num_snippets) * sr.fs).astype(int)
+    # for most recordings we take 100 s on either side, but use a smaller buffer for short recordings
+    buffer_left_right = np.minimum(100, sr.rl * 0.03)
+    start = (
+        np.linspace(buffer_left_right, int(sr.rl) - buffer_left_right, num_snippets)
+        * sr.fs
+    ).astype(int)
     end = start + int(sr.fs)

     _channel_labels = np.zeros((384, num_snippets), int)
@@ -101,10 +109,9 @@
 def _make_wfs_table(
     sr,
-    spike_times,
+    spike_samples,
     spike_clusters,
     spike_channels,
-    chunksize_t=10,
     max_wf=256,
     trough_offset=42,
     spike_length_samples=128,
@@ -118,8 +125,8 @@
     """
     # exclude spikes without a buffer on either end
     # of recording
-    allowed_idx = (spike_times > trough_offset) & (
-        spike_times < sr.ns - (spike_length_samples - trough_offset)
+    allowed_idx = (spike_samples > trough_offset) & (
+        spike_samples < sr.ns - (spike_length_samples - trough_offset)
     )
     rng = np.random.default_rng(seed=2024)  # numpy 1.23.5
@@ -136,7 +143,7 @@
         nspikes = u_spikeidx.shape[0]
         unit_nspikes[i] = nspikes
         # uniformly select up to 500 spikes
-        u_wf_idx = rng.choice(u_spikeidx, min(max_wf, nspikes))
+        u_wf_idx = rng.choice(u_spikeidx, min(max_wf, nspikes), replace=False)
         unit_wf_idx[u, : min(max_wf, nspikes)] = u_wf_idx

     # all wf indices in order
@@ -145,13 +152,12 @@
     wf_idx = wf_idx[np.nonzero(wf_idx)[0][0]:]

     # get sample times, clusters, channels
-
     wf_flat = pd.DataFrame(
         {
-            "indices": np.arange(wf_idx.shape[0]),
-            "samples": spike_times[wf_idx].astype(int),
-            "clusters": spike_clusters[wf_idx].astype(int),
-            "channels": spike_channels[wf_idx].astype(int),
+            "index": np.arange(wf_idx.shape[0]),
+            "sample": spike_samples[wf_idx].astype(int),
+            "cluster": spike_clusters[wf_idx].astype(int),
+            "peak_channel": spike_channels[wf_idx].astype(int),
         }
     )

@@ -176,6 +182,9 @@ def write_wfs_chunk(
     Parallel job to extract waveforms from chunk `i_chunk` of a recording `sr` and
     write them to the correct spot in the output .npy file `wfs_fn`.
     """
+    if len(wf_flat) == 0:
+        return
+
     my_sr = spikeglx.Reader(cbin)
     s0, s1 = sr_sl

@@ -197,13 +206,13 @@
     else:
         offset = trough_offset

-    sample = wf_flat["samples"].astype(int) + offset - i_chunk * chunksize_samples
-    peak_channel = wf_flat["channels"]
+    sample = wf_flat["sample"].astype(int) + offset - i_chunk * chunksize_samples
+    peak_channel = wf_flat["peak_channel"]

     df = pd.DataFrame({"sample": sample, "peak_channel": peak_channel})

     snip = my_sr[
-        s0 - offset: s1 + spike_length_samples - trough_offset, : -my_sr.nsync
+        s0 - offset:s1 + spike_length_samples - trough_offset, :-my_sr.nsync
     ]
     snip0 = interpolate_bad_channels(
         fshift(
@@ -216,7 +225,7 @@
     # car
     snip1 = np.full((my_sr.nc, snip0.shape[1]), np.nan)
     snip1[:-1, :] = car_func(snip0)
-    wfs_mmap[wf_flat["indices"], :, :] = extract_wfs_array(
+    wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array(
         snip1.T, df, channel_neighbors
     )[0]
     wfs_mmap.flush()
@@ -224,90 +233,101 @@
 def extract_wfs_cbin(
     cbin_file,
-    output_file,
-    spike_times,
+    output_dir,
+    spike_samples,
     spike_clusters,
     spike_channels,
     h=None,
-    wf_extract_params=None,
-    nprocesses=None,
+    channel_labels=None,
+    max_wf=256,
+    trough_offset=42,
+    spike_length_samples=128,
+    chunksize_samples=int(3000),
+    n_jobs=None,
 ):
     """
     Given a cbin file and locations of spikes, extract waveforms for each unit, compute
-    the templates, and save to `output_file`.
-
-    If `output_file=Path("/path/to/example_clusters.npy")`, this array will be of shape
-    `(num_units, max_wf, nc, spike_length_samples)` where by default `max_wf=256, nc=40,
-    spike_length_samples=128`.
-
-    The file "path/to/example_clusters_templates.npy" will also be generated, of shape
-    `(num_units, nc, spike_length_samples)`, where the median across waveforms is taken
-    for each unit.
-
-    The parquet file "path/to/example_clusters.pqt" contains the samples and max channels
-    of each waveform, indexed by unit.
+    the templates, and save the results in `output_dir`. The waveforms come from chunks
+    of raw data which are phase-corrected to account for the ADC sampling shifts,
+    high-pass filtered in time with an order-3 Butterworth filter with a 300 Hz cutoff,
+    and common-average referenced in the spatial dimension.
+
+    The following files will be generated:
+    - waveforms.traces.npy: `(num_units, max_wf, nc, spike_length_samples)`
+        This file contains the lightly processed waveforms indexed by cluster in the
+        first dimension. By default `max_wf=256, nc=40, spike_length_samples=128`.
+
+    - waveforms.templates.npy: `(num_units, nc, spike_length_samples)`
+        This file contains the median across individual waveforms for each unit.
+
+    - waveforms.channels.npz: `(num_units * max_wf, nc)`
+        The i'th row contains the ordered indices of the `nc`-channel neighborhood used
+        to extract the i'th waveform. A row of -1s means the waveform is missing because
+        the unit it was supposed to come from has fewer than `max_wf` spikes in the
+        recording.
+
+    - waveforms.table.pqt: `num_units * max_wf` rows
+        For each waveform, gives the absolute sample number from the recording (i.e.
+        where to find it in `spikes.samples`), peak channel, cluster, and linear index.
+        A row of NaNs (apart from the cluster id) means the waveform is missing because
+        the unit it was supposed to come from has fewer than `max_wf` spikes in total.
""" if h is None: h = neuropixel.trace_header() - if wf_extract_params is None: - wf_extract_params = { - "max_wf": 256, - "trough_offset": 42, - "spike_length_samples": 128, - "chunksize_t": 10, - } - - output_path = output_file.parent - - max_wf = wf_extract_params["max_wf"] - trough_offset = wf_extract_params["trough_offset"] - spike_length_samples = wf_extract_params["spike_length_samples"] - chunksize_t = wf_extract_params["chunksize_t"] + n_jobs = n_jobs or int(cpu_count() / 2) sr = spikeglx.Reader(cbin_file) - chunksize_samples = chunksize_t * 30_000 s0_arr = np.arange(0, sr.ns, chunksize_samples) s1_arr = s0_arr + chunksize_samples s1_arr[-1] = sr.ns + # selects spikes from throughout the recording for each unit wf_flat, unit_ids = _make_wfs_table( - sr, spike_times, spike_clusters, spike_channels, **wf_extract_params + sr, + spike_samples, + spike_clusters, + spike_channels, + max_wf, + trough_offset, + spike_length_samples, ) num_chunks = s0_arr.shape[0] - print(f"Chunk size: {chunksize_t}") - print(f"Num chunks: {num_chunks}") - print("Running channel detection") - channel_labels = _get_channel_labels(sr) + logger.info(f"Chunk size samples: {chunksize_samples}") + logger.info(f"Num chunks: {num_chunks}") + + logger.info("Running channel detection") + if channel_labels is None: + channel_labels = _get_channel_labels(sr) - nwf = wf_flat["samples"].shape[0] + nwf = len(wf_flat) nu = unit_ids.shape[0] - print(f"Extracting {nwf} waveforms from {nu} units") + logger.info(f"Extracting {nwf} waveforms from {nu} units") # get channel geometry geom = np.c_[h["x"], h["y"]] channel_neighbors = make_channel_index(geom) nc = channel_neighbors.shape[1] - fn = output_path.joinpath("_wf_extract_intermediate.npy") + # this intermediate memmap is written to in parallel + # the waveforms are ordered only by their chronological position + # in the recording, as we are reading them in time chunks + int_fn = output_dir.joinpath("_wf_extract_intermediate.npy") wfs = open_memmap( - fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32 + int_fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32 ) slices = [ - slice( - *(np.searchsorted(wf_flat["samples"], [s0_arr[i], s1_arr[i]]).astype(int)) - ) + slice(*(np.searchsorted(wf_flat["sample"], [s0_arr[i], s1_arr[i]]).astype(int))) for i in range(num_chunks) ] - nprocesses = nprocesses or int(cpu_count() - cpu_count() / 4) - _ = Parallel(n_jobs=nprocesses)( + _ = Parallel(n_jobs=n_jobs)( delayed(write_wfs_chunk)( i, cbin_file, - fn, + int_fn, wfs.shape, h, channel_labels, @@ -321,34 +341,81 @@ def extract_wfs_cbin( for i in range(num_chunks) ) - wfs = open_memmap( - fn, mode="r+", shape=(nwf, nc, spike_length_samples), dtype=np.float32 - ) - # bookkeeping - wfs_by_unit = np.full( - (nu, max_wf, nc, spike_length_samples), np.nan, dtype=np.float16 + # output files + traces_fn = output_dir.joinpath("waveforms.traces.npy") + templates_fn = output_dir.joinpath("waveforms.templates.npy") + table_fn = output_dir.joinpath("waveforms.table.pqt") + channels_fn = output_dir.joinpath("waveforms.channels.npz") + + ## rearrange and save traces by unit + # store medians across waveforms + wfs_templates = np.full((nu, nc, spike_length_samples), np.nan, dtype=np.float32) + # create waveform output file (~2-3 GB) + traces_by_unit = open_memmap( + traces_fn, + mode="w+", + shape=(nu, max_wf, nc, spike_length_samples), + dtype=np.float16, ) - wfs_medians = np.full((nu, nc, spike_length_samples), np.nan, dtype=np.float32) - print("Computing 
templates") - for i, u in enumerate(unit_ids): - _wfs_unit = wfs[wf_flat["clusters"] == u] - nwf_u = _wfs_unit.shape[0] - wfs_by_unit[i, : min(max_wf, nwf_u), :, :] = _wfs_unit.astype(np.float16) - wfs_medians[i, :, :] = np.nanmedian(_wfs_unit, axis=0) + logger.info("Writing to output files") - df = pd.DataFrame( + for i, u in enumerate(unit_ids): + idx = np.where(wf_flat["cluster"] == u)[0] + nwf_u = idx.shape[0] + # reopening these memmaps on each iteration + # forces Python to clean up each large array it loads + # and prevent a memory leak + wfs = open_memmap( + int_fn, mode="r+", shape=(nwf, nc, spike_length_samples), dtype=np.float32 + ) + traces_by_unit = open_memmap( + traces_fn, + mode="r+", + shape=(nu, max_wf, nc, spike_length_samples), + dtype=np.float16, + ) + # write up to 256 waveforms and leave the rest of dimensions 1-3 as NaNs + traces_by_unit[i, : min(max_wf, nwf_u), :, :] = wfs[idx].astype(np.float16) + traces_by_unit.flush() + # populate this array in memory as it's 256x smaller + wfs_templates[i, :, :] = np.nanmedian(wfs[idx], axis=0) + + # cleanup intermediate file + int_fn.unlink() + + # save templates + np.save(templates_fn, wfs_templates) + + # add in dummy rows and order by unit, and then sample + unit_counts = wf_flat.groupby("cluster")["sample"].count().reset_index(name="count") + unit_counts["missing"] = 256 - unit_counts["count"] + missing_wf = unit_counts[unit_counts["missing"] > 0] + total_missing = sum(missing_wf.missing) + extra_rows = pd.DataFrame( { - "sample": wf_flat["samples"], - "peak_channel": wf_flat["channels"], - "cluster": wf_flat["clusters"], + "sample": [np.nan] * total_missing, + "peak_channel": [np.nan] * total_missing, + "index": [np.nan] * total_missing, + "cluster": sum( + [[row["cluster"]] * row["missing"] for _, row in missing_wf.iterrows()], + [], + ), } ) - df = df.sort_values(["cluster", "sample"]).set_index(["cluster", "sample"]) - - np.save(output_file, wfs_by_unit) - # medians - avg_file = output_file.parent.joinpath(output_file.stem + "_templates.npy") - np.save(avg_file, wfs_medians) - df.to_parquet(output_file.with_suffix(".pqt")) - - fn.unlink() + save_df = pd.concat([wf_flat, extra_rows]) + # now the waveforms are arranged by cluster, and then in time + # these match dimensions 0 and 1 of waveforms.traces.npy + save_df.sort_values(["cluster", "sample"], inplace=True) + save_df.to_parquet(table_fn) + + # save channel map for each waveform + # these values are now reordered so that they match the pqt + # and the traces file + peak_channel = np.nan_to_num(save_df["peak_channel"].to_numpy(), nan=-1).astype( + np.int16 + ) + dummy_idx = np.where(peak_channel >= 0)[0] + # leave "missing" waveforms as -1 since we can't have NaN with int dtype + chan_map = np.ones((max_wf * nu, nc), np.int16) * -1 + chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)] + np.savez(channels_fn, channels=chan_map)