diff --git a/brainbox/atlas.py b/brainbox/atlas.py deleted file mode 100644 index 8c28feb89..000000000 --- a/brainbox/atlas.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -Functions which map metrics to the Allen atlas. - -Code by G. Meijer -""" - -import numpy as np -import seaborn as sns -import matplotlib.pyplot as plt -from iblatlas import atlas - - -def _label2values(imlabel, fill_values, ba): - """ - Fills a slice from the label volume with values to display - :param imlabel: 2D np-array containing label ids (slice of the label volume) - :param fill_values: 1D np-array containing values to fill into the slice - :return: 2D np-array filled with values - """ - im_unique, ilabels, iim = np.unique(imlabel, return_index=True, return_inverse=True) - _, ir_unique, _ = np.intersect1d(ba.regions.id, im_unique, return_indices=True) - im = np.squeeze(np.reshape(fill_values[ir_unique[iim]], (*imlabel.shape, 1))) - return im - - -def plot_atlas(regions, values, ML=-1, AP=0, DV=-1, hemisphere='left', color_palette='Reds', - minmax=None, axs=None, custom_region_list=None): - """ - Plot a sagittal, coronal and horizontal slice of the Allen atlas with regions colored in - according to any value that the user specifies. - - Parameters - ---------- - regions : 1D array - Array of strings with the acronyms of brain regions (in Allen convention) that should be - filled with color - values : 1D array - Array of values that correspond to the brain region acronyms - ML, AP, DV : float - The coordinates of the slices in mm - hemisphere : string - Which hemisphere to color, options are 'left' (default), 'right', 'both' - color_palette : any input that can be interpreted by sns.color_palette - The color palette of the plot - minmax : 2 element array - The min and max of the color map, if None it uses the min and max of values - axs : 3 element list of axis - A list of the three axis in which to plot the three slices - custom_region_list : 1D array with shape the same as ba.regions.acronym.shape - Input any custom list of acronyms that replaces the default list of acronyms - found in ba.regions.acronym. 
For example if you want to merge certain regions you can - give them the same name in the custom_region_list - """ - - # Import Allen atlas - ba = atlas.AllenAtlas(25) - - # Check input - assert regions.shape == values.shape - if minmax is not None: - assert len(minmax) == 2 - if axs is not None: - assert len(axs) == 3 - if custom_region_list is not None: - assert custom_region_list.shape == ba.regions.acronym.shape - - # Get region boundaries volume - boundaries = np.diff(ba.label, axis=0, append=0) - boundaries = boundaries + np.diff(ba.label, axis=1, append=0) - boundaries = boundaries + np.diff(ba.label, axis=2, append=0) - boundaries[boundaries != 0] = 1 - - # Get all brain region names, use custom list if inputted - if custom_region_list is None: - all_regions = ba.regions.acronym - else: - all_regions = custom_region_list - - # Set values outside colormap bounds - if minmax is not None: - values[values < minmax[0] + np.abs(minmax[0] / 1000)] = (minmax[0] - + np.abs(minmax[0] / 1000)) - values[values > minmax[1] - np.abs(minmax[1] / 1000)] = (minmax[1] - - np.abs(minmax[0] / 1000)) - - # Add values to brain region list - region_values = np.ones(ba.regions.acronym.shape) * (np.min(values) - (np.max(values) + 1)) - for i, region in enumerate(regions): - region_values[all_regions == region] = values[i] - - # Set 'void' to default white - region_values[0] = np.min(values) - (np.max(values) + 1) - - # Get slices with fill values - slice_sag = ba.slice(ML / 1000, axis=0, volume=ba.label) # saggital - slice_sag = _label2values(slice_sag, region_values, ba) - bound_sag = ba.slice(ML / 1000, axis=0, volume=boundaries) - slice_cor = ba.slice(AP / 1000, axis=1, volume=ba.label) # coronal - slice_cor = _label2values(slice_cor, region_values, ba) - bound_cor = ba.slice(AP / 1000, axis=1, volume=boundaries) - slice_hor = ba.slice(DV / 1000, axis=2, volume=ba.label) # horizontal - slice_hor = _label2values(slice_hor, region_values, ba) - bound_hor = ba.slice(DV / 1000, axis=2, volume=boundaries) - - # Only color specified hemisphere - if hemisphere == 'left': - slice_cor[:int(slice_cor.shape[0] / 2), :] = np.min(values) - (np.max(values) + 1) - slice_hor[:, int(slice_cor.shape[0] / 2):] = np.min(values) - (np.max(values) + 1) - elif hemisphere == 'right': - slice_cor[int(slice_cor.shape[0] / 2):, :] = np.min(values) - (np.max(values) + 1) - slice_hor[:, :int(slice_cor.shape[0] / 2)] = np.min(values) - (np.max(values) + 1) - if ((hemisphere == 'left') & (ML > 0)) or ((hemisphere == 'right') & (ML < 0)): - slice_sag[:] = np.min(values) - (np.max(values) + 1) - - # Add boundaries to slices outside of the fill value region and set to grey - if minmax is None: - slice_sag[bound_sag == 1] = np.max(values) + 1 - slice_cor[bound_cor == 1] = np.max(values) + 1 - slice_hor[bound_hor == 1] = np.max(values) + 1 - else: - slice_sag[bound_sag == 1] = minmax[1] + 1 - slice_cor[bound_cor == 1] = minmax[1] + 1 - slice_hor[bound_hor == 1] = minmax[1] + 1 - - # Construct color map - color_map = sns.color_palette(color_palette, 1000) - color_map.append((0.8, 0.8, 0.8)) # color of the boundaries between regions - color_map.insert(0, (1, 1, 1)) # color of the background and regions without a value - - # Get color scale - if minmax is None: - cmin = np.min(values) - cmax = np.max(values) - else: - cmin = minmax[0] - cmax = minmax[1] - - # Plot - if axs is None: - fig, axs = plt.subplots(1, 3, figsize=(16, 4)) - - # Saggital - sns.heatmap(np.rot90(slice_sag, 3), cmap=color_map, cbar=True, vmin=cmin, vmax=cmax, ax=axs[0]) 
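A rough standalone sketch of the colour-map trick used a few lines above (illustrative only, not part of the original file; palette name and data are arbitrary): because seaborn accepts a plain list of colours as `cmap`, prepending white and appending grey reserves the lowest bin for the background/unfilled regions and the highest bin for the boundary mask.

import numpy as np
import seaborn as sns
from matplotlib.colors import ListedColormap

palette = sns.color_palette('Reds', 10)   # any palette accepted by sns.color_palette
palette.append((0.8, 0.8, 0.8))           # top bin: grey for region boundaries
palette.insert(0, (1, 1, 1))              # bottom bin: white for background / unfilled regions
cmap = ListedColormap(palette)            # equivalent to passing the list straight to sns.heatmap
demo = np.random.rand(5, 5)               # toy data, stands in for an atlas slice
# sns.heatmap(demo, cmap=cmap, vmin=0, vmax=1)  # hypothetical call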
- axs[0].set(title='ML: %.1f mm' % ML) - plt.axis('off') - axs[0].get_xaxis().set_visible(False) - axs[0].get_yaxis().set_visible(False) - - # Coronal - sns.heatmap(np.rot90(slice_cor, 3), cmap=color_map, cbar=True, vmin=cmin, vmax=cmax, ax=axs[1]) - axs[1].set(title='AP: %.1f mm' % AP) - plt.axis('off') - axs[1].get_xaxis().set_visible(False) - axs[1].get_yaxis().set_visible(False) - - # Horizontal - sns.heatmap(np.rot90(slice_hor, 3), cmap=color_map, cbar=True, vmin=cmin, vmax=cmax, ax=axs[2]) - axs[2].set(title='DV: %.1f mm' % DV) - plt.axis('off') - axs[2].get_xaxis().set_visible(False) - axs[2].get_yaxis().set_visible(False) diff --git a/brainbox/lfp.py b/brainbox/lfp.py deleted file mode 100644 index 8b407d2f3..000000000 --- a/brainbox/lfp.py +++ /dev/null @@ -1,114 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Functions to analyse LFP signals. - -@author: Guido Meijer -Created on Fri Mar 13 14:57:53 2020 -""" - -from scipy.signal import welch, csd, filtfilt, butter -import numpy as np - - -def butter_filter(signal, highpass_freq=None, lowpass_freq=None, order=4, fs=2500): - - # The filter type is determined according to the values of cut-off frequencies - Fn = fs / 2. - if lowpass_freq and highpass_freq: - if highpass_freq < lowpass_freq: - Wn = (highpass_freq / Fn, lowpass_freq / Fn) - btype = 'bandpass' - else: - Wn = (lowpass_freq / Fn, highpass_freq / Fn) - btype = 'bandstop' - elif lowpass_freq: - Wn = lowpass_freq / Fn - btype = 'lowpass' - elif highpass_freq: - Wn = highpass_freq / Fn - btype = 'highpass' - else: - raise ValueError("Either highpass_freq or lowpass_freq must be given") - - # Filter signal - b, a = butter(order, Wn, btype=btype, output='ba') - filtered_data = filtfilt(b=b, a=a, x=signal, axis=1) - - return filtered_data - - -def power_spectrum(signal, fs=2500, segment_length=0.5, segment_overlap=0.5, scaling='density'): - """ - Calculate the power spectrum of an LFP signal - - Parameters - ---------- - signal : 2D array - LFP signal from different channels in V with dimensions (channels X samples) - fs : int - Sampling frequency - segment_length : float - Length of the segments for which the spectral density is calcualted in seconds - segment_overlap : float - Fraction of overlap between the segments represented as a float number between 0 (no - overlap) and 1 (complete overlap) - - Returns - ---------- - freqs : 1D array - Frequencies for which the spectral density is calculated - psd : 2D array - Power spectrum in V^2 with dimensions (channels X frequencies) - - """ - - # Transform segment from seconds to samples - segment_samples = int(fs * segment_length) - overlap_samples = int(segment_overlap * segment_samples) - - # Calculate power spectrum - freqs, psd = welch(signal, fs=fs, nperseg=segment_samples, noverlap=overlap_samples, - scaling=scaling) - return freqs, psd - - -def coherence(signal_a, signal_b, fs=2500, segment_length=1, segment_overlap=0.5): - """ - Calculate the coherence between two LFP signals - - Parameters - ---------- - signal_a : 1D array - LFP signal from different channels with dimensions (channels X samples) - fs : int - Sampling frequency - segment_length : float - Length of the segments for which the spectral density is calcualted in seconds - segment_overlap : float - Fraction of overlap between the segments represented as a float number between 0 (no - overlap) and 1 (complete overlap) - - Returns - ---------- - freqs : 1D array - Frequencies for which the coherence is calculated - coherence : 1D array - Coherence takes a value between 0 
and 1, with 0 or 1 representing no or perfect coherence, - respectively - phase_lag : 1D array - Estimate of phase lag in radian between the input time series for each frequency - - """ - - # Transform segment from seconds to samples - segment_samples = int(fs * segment_length) - overlap_samples = int(segment_overlap * segment_samples) - - # Calculate coherence - freqs, Pxx = welch(signal_a, fs=fs, nperseg=segment_samples, noverlap=overlap_samples) - _, Pyy = welch(signal_b, fs=fs, nperseg=segment_samples, noverlap=overlap_samples) - _, Pxy = csd(signal_a, signal_b, fs=fs, nperseg=segment_samples, noverlap=overlap_samples) - coherence = np.abs(Pxy) ** 2 / (Pxx * Pyy) - phase_lag = np.angle(Pxy) - - return freqs, coherence, phase_lag diff --git a/brainbox/quality/__init__.py b/brainbox/quality/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/brainbox/quality/lfp_qc.py b/brainbox/quality/lfp_qc.py deleted file mode 100644 index 2dc50d0e3..000000000 --- a/brainbox/quality/lfp_qc.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -Quality control on LFP -Based on code by Olivier Winter: -https://github.com/int-brain-lab/ibllib/blob/master/examples/ibllib/ephys_qc_raw.py - -adapted by Anne Urai -Instructions https://docs.google.com/document/d/1lBNcssodWdBILBN0PrPWi0te4f6I8H_Bb-7ESxQdv5U/edit# -""" - -from pathlib import Path - -import sys -import glob -import os -import matplotlib.pyplot as plt -import numpy as np -import seaborn as sns - - -from ibllib.ephys import ephysqc -import one.alf.io as alfio -# from IPython import embed as shell - - -def _plot_spectra(outpath, typ, savefig=True): - ''' - TODO document this function - ''' - - spec = alfio.load_object(outpath, 'ephysQcFreq' + typ.upper(), namespace='spikeglx') - - # hack to ensure a single key name - if 'power.probe_00' in spec.keys(): - spec['power'] = spec.pop('power.probe_00') - spec['freq'] = spec.pop('freq.probe_00') - elif 'power.probe_01' in spec.keys(): - spec['power'] = spec.pop('power.probe_01') - spec['freq'] = spec.pop('freq.probe_01') - - # plot - sns.set_style("whitegrid") - plt.figure(figsize=[9, 4.5]) - ax = plt.axes() - ax.plot(spec['freq'], 20 * np.log10(spec['power'] + 1e-14), - linewidth=0.5, color=[0.5, 0.5, 0.5]) - ax.plot(spec['freq'], 20 * np.log10(np.median(spec['power'] + 1e-14, axis=1)), label='median') - ax.set_xlabel(r'Frequency (Hz)') - ax.set_ylabel(r'dB rel to $V^2.$Hz$^{-1}$') - if typ == 'ap': - ax.set_ylim([-275, -125]) - elif typ == 'lf': - ax.set_ylim([-260, -60]) - ax.legend() - ax.set_title(outpath) - if savefig: - plt.savefig(outpath / (typ + '_spec.png'), dpi=150) - print('saved figure to %s' % (outpath / (typ + '_spec.png'))) - - -def _plot_rmsmap(outpath, typ, savefig=True): - ''' - TODO document this function - ''' - - rmsmap = alfio.load_object(outpath, 'ephysQcTime' + typ.upper(), namespace='spikeglx') - - # hack to ensure a single key name - if 'times.probe_00' in rmsmap.keys(): - rmsmap['times'] = rmsmap.pop('times.probe_00') - rmsmap['rms'] = rmsmap.pop('rms.probe_00') - elif 'times.probe_01' in rmsmap.keys(): - rmsmap['times'] = rmsmap.pop('times.probe_01') - rmsmap['rms'] = rmsmap.pop('rms.probe_01') - - plt.figure(figsize=[12, 4.5]) - axim = plt.axes([0.2, 0.1, 0.7, 0.8]) - axrms = plt.axes([0.05, 0.1, 0.15, 0.8]) - axcb = plt.axes([0.92, 0.1, 0.02, 0.8]) - - axrms.plot(np.median(rmsmap['rms'], axis=0)[:-1] * 1e6, np.arange(1, rmsmap['rms'].shape[1])) - axrms.set_ylim(0, rmsmap['rms'].shape[1]) - - im = axim.imshow(20 * np.log10(rmsmap['rms'].T + 1e-15), aspect='auto', 
origin='lower', - extent=[rmsmap['times'][0], rmsmap['times'][-1], 0, rmsmap['rms'].shape[1]]) - axim.set_xlabel(r'Time (s)') - axrms.set_ylabel(r'Channel Number') - plt.colorbar(im, cax=axcb) - if typ == 'ap': - im.set_clim(-110, -90) - axrms.set_xlim(100, 0) - elif typ == 'lf': - im.set_clim(-100, -60) - axrms.set_xlim(500, 0) - axim.set_xlim(0, 4000) - axim.set_title(outpath) - if savefig: - plt.savefig(outpath / (typ + '_rms.png'), dpi=150) - - -# ============================== ### -# FIND THE RIGHT FILES, RUN AS SCRIPT -# ============================== ### - -if __name__ == '__main__': - if len(sys.argv) != 2: - print("Please give the folder path as an input argument!") - else: - outpath = Path(sys.argv[1]) # grab from command line input - fbin = glob.glob(os.path.join(outpath, '*.lf.bin')) - assert len(fbin) > 0 - print('fbin: %s' % fbin) - # make sure you send a path for the time being and not a string - ephysqc.extract_rmsmap(Path(fbin[0])) - _plot_spectra(outpath, 'lf') - _plot_rmsmap(outpath, 'lf') diff --git a/brainbox/quality/permutation_test.py b/brainbox/quality/permutation_test.py deleted file mode 100644 index 6a58855d8..000000000 --- a/brainbox/quality/permutation_test.py +++ /dev/null @@ -1,105 +0,0 @@ -""" -Quality control for arbitrary metrics, using permutation testing. - -Written by Sebastian Bruijns -""" - -import numpy as np -import time -import matplotlib.pyplot as plt -# TODO: take in eids and download data yourself? - - -def permut_test(data1, data2, metric, n_permut=1000, show=False, title=None): - """ - Compute the probability of observating metric difference for datasets, via permutation testing. - - We're taking absolute values of differences, because the order of dataset input shouldn't - matter - We're only computing means, what if we want to apply a more complicated function to the - permutation result? - Pay attention to always give one list (even if its just one dataset, but then it doesn't make - sense anyway...) - - Parameters - ---------- - data1 : array-like - First data set, list or array of data-entities to use for permutation test - (make data2 optional and then permutation test more similar to tuning sensitivity?) 
- data2 : array-like - Second data set, also list or array of data-entities to use for permutation test - metric : function, array-like -> float - Metric to use for permutation test, will be used to reduce elements of data1 and data2 - to one number - n_permut : integer (optional) - Number of perumtations to use for test - plot : Boolean (optional) - Whether or not to show a plot of the permutation distribution and a marker for the position - of the true difference in relation to this distribution - - Returns - ------- - p : float - p-value of true difference in permutation distribution - - See Also - -------- - TODO: - - Examples - -------- - TODO: - """ - # Calculate metrics and true difference between groups - print('data1') - print(data1) - metrics1 = [metric(d) for d in data1] - print('metrics1') - print(metrics1) - metrics2 = [metric(d) for d in data2] - true_diff = np.abs(np.mean(metrics1) - np.mean(metrics2)) - - # Prepare permutations - size1 = len(metrics1) - diffs = np.concatenate((metrics1, metrics2)) - permutations = np.zeros((n_permut, diffs.size), dtype=np.int32) - - # Create permutations, could be parallelized or vectorized in principle, but unclear how - indizes = np.arange(diffs.size) - for i in range(n_permut): - np.random.shuffle(indizes) - permutations[i] = indizes - - permut_diffs = np.abs(np.mean(diffs[permutations[:, :size1]], axis=1) - - np.mean(diffs[permutations[:, size1:]], axis=1)) - p = len(permut_diffs[permut_diffs > true_diff]) / n_permut - - if show or title: - plot_permut_test(permut_diffs=permut_diffs, true_diff=true_diff, p=p, title=title) - - return p - - -def plot_permut_test(permut_diffs, true_diff, p, title=None): - """Plot permutation test result.""" - n, _, _ = plt.hist(permut_diffs) - plt.plot(true_diff, np.max(n) / 20, '*r', markersize=12) - - # Prettify plot - plt.gca().spines['top'].set_visible(False) - plt.gca().spines['right'].set_visible(False) - plt.title("p = {}".format(p)) - - if title: - plt.savefig(title + '.png') - plt.close() - - -if __name__ == '__main__': - rng = np.random.RandomState(2) - data1 = rng.normal(0, 1, (23, 5)) - data2 = rng.normal(0.1, 1, (32, 5)) - t = time.time() - p = permut_test(data1, data2, np.mean, plot=True) - print(time.time() - t) - print(p) diff --git a/brainbox/spike_features.py b/brainbox/spike_features.py deleted file mode 100644 index da31cb0c3..000000000 --- a/brainbox/spike_features.py +++ /dev/null @@ -1,107 +0,0 @@ -''' -Functions that compute spike features from spike waveforms. -''' - -import numpy as np -from brainbox.io.spikeglx import extract_waveforms - - -def depth(ephys_file, spks_b, clstrs_b, chnls_b, tmplts_b, unit, n_ch=12, n_ch_probe=385, sr=30000, - dtype='int16', car=False): - ''' - Gets `n_ch` channels around a unit's channel of max amplitude, extracts all unit spike - waveforms from binary datafile for these channels, and for each spike, computes the dot - products of waveform by unit template for those channels, and computes center-of-mass of these - dot products to get spike depth estimates. - - Parameters - ---------- - ephys_file : string - The file path to the binary ephys data. - spks_b : bunch - A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features, - etc.) for all spikes. - clstrs_b : bunch - A clusters bunch containing fields with cluster information (e.g. amp, ch of max amp, depth - of ch of max amp, etc.) for all clusters. - chnls_b : bunch - A channels bunch containing fields with channel information (e.g. coordinates, indices, - etc.) 
for all probe channels. - tmplts_b : bunch - A unit templates bunch containing fields with unit template information (e.g. template - waveforms, etc.) for all unit templates. - unit : numeric - The unit for which to return the spikes depths. - n_ch : int (optional) - The number of channels to sample around the channel of max amplitude to compute the depths. - sr : int (optional) - The sampling rate (hz) that the ephys data was acquired at. - n_ch_probe : int (optional) - The number of channels of the recording. - dtype: str (optional) - The datatype represented by the bytes in `ephys_file`. - car: bool (optional) - A flag to perform common-average-referencing before extracting waveforms. - - Returns - ------- - d : ndarray - The estimated spike depths for all spikes in `unit`. - - See Also - -------- - io.extract_waveforms - - Examples - -------- - 1) Get the spike depths for unit 1. - >>> import numpy as np - >>> import brainbox as bb - >>> import alf.io as aio - >>> import ibllib.ephys.spikes as e_spks - (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory): - >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out) - # Get the necessary alf objects from an alf directory. - >>> spks_b = aio.load_object(path_to_alf_out, 'spikes') - >>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters') - >>> chnls_b = aio.load_object(path_to_alf_out, 'channels') - >>> tmplts_b = aio.load_object(path_to_alf_out, 'templates') - # Compute spike depths. - >>> unit1_depths = bb.spike_features.depth(path_to_ephys_file, spks_b, clstrs_b, chnls_b, - tmplts_b, unit=1) - ''' - - # Set constants: # - n_c_ch = n_ch // 2 # number of close channels to take on either side of max channel - - # Get unit waveforms: # - # Get unit timestamps. - unit_spk_indxs = np.where(spks_b['clusters'] == unit)[0] - ts = spks_b['times'][unit_spk_indxs] - # Get `n_close_ch` channels around channel of max amplitude. - max_ch = clstrs_b['channels'][unit] - if max_ch < n_c_ch: # take only channels greater than `max_ch`. - ch = np.arange(max_ch, max_ch + n_ch) - elif (max_ch + n_c_ch) > n_ch_probe: # take only channels less than `max_ch`. - ch = np.arange(max_ch - n_ch, max_ch) - else: # take `n_c_ch` around `max_ch`. - ch = np.arange(max_ch - n_c_ch, max_ch + n_c_ch) - # Get unit template across `ch` and extract waveforms from `ephys_file`. 
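To make the centre-of-mass step computed further down concrete, here is a minimal numerical sketch (editorial illustration with made-up numbers, not part of the original function): each channel is weighted by its normalised waveform-template dot product, and the spike depth is the weighted mean of the channel depths.

import numpy as np

example_ch_depths = np.array([20., 40., 60., 80.])   # hypothetical channel depths (um)
example_weights = np.array([0.1, 0.9, 1.0, 0.3])     # normalised dot products per channel
c_o_m = np.sum(example_weights * example_ch_depths) / np.sum(example_weights)
# c_o_m is ~53 um, i.e. the spike localises between the two strongest channels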
- tmplt_wfs = tmplts_b['waveforms'] - unit_tmplt = tmplt_wfs[unit, :, ch].T - wf_t = tmplt_wfs.shape[1] / (sr / 1000) # duration (ms) of each waveform - wf = extract_waveforms(ephys_file=ephys_file, ts=ts, ch=ch, t=wf_t, sr=sr, - n_ch_probe=n_ch_probe, dtype='int16', car=car) - - # Compute center-of-mass: # - ch_depths = chnls_b['localCoordinates'][[ch], [1]] - d = np.zeros_like(ts) # depths array - # Compute normalized dot product of (waveforms,unit_template) across `ch`, - # and get center-of-mass, `c_o_m`, of these dot products (one dot for each ch) - for spk in range(len(ts)): - dot_wf_template = np.sum(wf[spk, :, :] * unit_tmplt, axis=0) - dot_wf_template += np.abs(np.min(dot_wf_template)) - dot_wf_template /= np.max(dot_wf_template) - c_o_m = (1 / np.sum(dot_wf_template)) * np.sum(dot_wf_template * ch_depths) - d[spk] = c_o_m - return d diff --git a/ibllib/io/npy_header.py b/ibllib/io/npy_header.py deleted file mode 100644 index 03511893a..000000000 --- a/ibllib/io/npy_header.py +++ /dev/null @@ -1,19 +0,0 @@ -from collections import namedtuple -import ast - - -def read(filename): - header = namedtuple('npy_header', - 'magic_string, version, header_len, descr, fortran_order, shape') - - with open(filename, 'rb') as fid: - header.magic_string = fid.read(6) - header.version = fid.read(2) - header.header_len = int.from_bytes(fid.read(2), byteorder='little') - d = ast.literal_eval(fid.read(header.header_len).decode()) - - for k in d.keys(): - print(k) - setattr(header, k, d[k]) - - return header diff --git a/ibllib/oneibl/data_handlers.py b/ibllib/oneibl/data_handlers.py index 176d10586..2ccc0960d 100644 --- a/ibllib/oneibl/data_handlers.py +++ b/ibllib/oneibl/data_handlers.py @@ -45,7 +45,7 @@ def __init__(self, name, collection, register=None, revision=None, unique=True): collection : str, None An ALF collection or pattern. register : bool - Whether to register the output file. Default is False for input files, True for output + Whether to register the file. Default is False for input files, True for output files. revision : str An optional revision. @@ -305,7 +305,7 @@ def input(name, collection, required=True, register=False, **kwargs): required : bool Whether file must always be present, or is an optional dataset. Default is True. register : bool - Whether to register the output file. Default is False for input files, True for output + Whether to register the input file. Default is False for input files, True for output files. revision : str An optional revision. diff --git a/ibllib/pipes/purge_rig_data.py b/ibllib/pipes/purge_rig_data.py deleted file mode 100644 index 823076dbd..000000000 --- a/ibllib/pipes/purge_rig_data.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -Purge data from acquisition PC. 
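The npy_header.read() helper removed above simply walks the standard .npy preamble by hand. A minimal round-trip (illustrative; the file path is arbitrary and the 2-byte header length assumes format version 1.0, as the deleted reader also did) shows the fields it recovers:

import ast
import numpy as np

np.save('/tmp/example.npy', np.zeros((3, 4), dtype=np.float32))
with open('/tmp/example.npy', 'rb') as fid:
    magic_string = fid.read(6)                          # b'\x93NUMPY'
    version = fid.read(2)                               # b'\x01\x00' for format 1.0
    header_len = int.from_bytes(fid.read(2), byteorder='little')
    meta = ast.literal_eval(fid.read(header_len).decode())
# meta -> {'descr': '<f4', 'fortran_order': False, 'shape': (3, 4)}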
- -Steps: - -- Find all files by rglob -- Find all sessions of the found files -- Check Alyx if corresponding datasetTypes have been registered as existing - sessions and files on Flatiron -- Delete local raw file if found on Flatiron -""" -import argparse -from pathlib import Path - -from one.api import ONE -from one.alf.path import get_session_path - - -def session_name(path) -> str: - """Returns the session name (subject/date/number) string for any filepath - using session_path""" - return '/'.join(get_session_path(path).parts[-3:]) - - -def purge_local_data(local_folder, file_name, lab=None, dry=False): - # Figure out datasetType from file_name or file path - file_name = Path(file_name).name - alf_parts = file_name.split('.') - dstype = '.'.join(alf_parts[:2]) - print(f'Looking for file <{file_name}> in folder <{local_folder}>') - # Get all paths for file_name in local folder - local_folder = Path(local_folder) - files = list(local_folder.rglob(f'*{file_name}')) - print(f'Found {len(files)} files') - print(f'Checking on Flatiron for datsetType: {dstype}...') - # Get all sessions and details from Alyx that have the dstype - one = ONE(cache_rest=None) - if lab is None: - eid, det = one.search(dataset_types=[dstype], details=True) - else: - eid, det = one.search(dataset_types=[dstype], lab=lab, details=True) - urls = [] - for d in det: - urls.extend([x['data_url'] for x in d['data_dataset_session_related'] - if x['dataset_type'] == dstype]) - # Remove None answers when session is registered but dstype not htere yet - urls = [u for u in urls if u is not None] - print(f'Found files on Flatiron: {len(urls)}') - to_remove = [] - for f in files: - sess_name = session_name(f) - for u in urls: - if sess_name in u: - to_remove.append(f) - print(f'Local files to remove: {len(to_remove)}') - for f in to_remove: - print(f) - if dry: - continue - else: - f.unlink() - return - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Delete files from rig') - parser.add_argument('folder', help='Local iblrig_data folder') - parser.add_argument( - 'file', help='File name to search and destroy for every session') - parser.add_argument('-lab', required=False, default=None, - help='Lab name, search on Alyx faster. default: None') - parser.add_argument('--dry', required=False, default=False, - action='store_true', help='Dry run? 
default: False') - args = parser.parse_args() - purge_local_data(args.folder, args.file, lab=args.lab, dry=args.dry) - print('Done\n') diff --git a/ibllib/pipes/remote_server.py b/ibllib/pipes/remote_server.py deleted file mode 100644 index 1e7a5f702..000000000 --- a/ibllib/pipes/remote_server.py +++ /dev/null @@ -1,143 +0,0 @@ -import logging -from pathlib import Path, PosixPath -import re -import subprocess -import os - -from ibllib.ephys import sync_probes -from ibllib.pipes import ephys_preprocessing as ephys -from ibllib.oneibl.patcher import FTPPatcher -from one.api import ONE - -_logger = logging.getLogger(__name__) - -FLATIRON_HOST = 'ibl.flatironinstitute.org' -FLATIRON_PORT = 61022 -FLATIRON_USER = 'datauser' -root_path = '/mnt/s0/Data/' - - -def _run_command(cmd): - process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - info, error = process.communicate() - if process.returncode != 0: - return None, error.decode('utf-8') - else: - return info.decode('utf-8').strip(), None - - -def job_transfer_ks2(probe_path): - - assert isinstance(probe_path, str) - - def _get_volume_usage_percentage(vol): - cmd = f'df {vol}' - res, _ = _run_command(cmd) - size_list = re.split(' +', res.split('\n')[-1]) - per_usage = int(size_list[4][:-1]) - return per_usage - - # First check disk availability - space = _get_volume_usage_percentage('/mnt/s0') - # If we are less than 80% full we can transfer more stuff - if space < 80: - # Transfer data from flatiron to s3 - cmd = f'ssh -i ~/.ssh/mayo_alyx.pem -p {FLATIRON_PORT} ' \ - f'{FLATIRON_USER}@{FLATIRON_HOST} ./transfer_to_aws.sh {probe_path}' - result, error = _run_command(cmd) - - # Check that command has run as expected and output info to logger - if not result: - _logger.error(f'{probe_path}: Could not transfer data from FlatIron to s3 \n' - f'Error: {error}') - return - else: - _logger.info(f'{probe_path}: Data transferred from FlatIron to s3') - - # Transfer data from s3 to /mnt/s0/Data on aws - session = str(PosixPath(probe_path).parent.parent) - cmd = f'aws s3 sync s3://ibl-ks2-storage/{session} "/mnt/s0/Data/{session}"' - result, error = _run_command(cmd) - - # Check that command has run as expected and output info to logger - if not result: - _logger.error(f'{probe_path}: Could not transfer data from s3 to aws \n' - f'Error: {error}') - return - else: - _logger.info(f'{probe_path}: Data transferred from s3 to aws') - - # Rename the files to get rid of eid associated with each dataset - session_path = Path(root_path).joinpath(session) - for file in session_path.glob('**/*'): - if len(Path(file.stem).suffix) == 37: - file.rename(Path(file.parent, str(Path(file.stem).stem) + file.suffix)) - _logger.info(f'Renamed dataset {file.stem} to {str(Path(file.stem).stem)}') - else: - _logger.warning(f'Dataset {file.stem} not renamed') - continue - - # Create a sort_me.flag - cmd = f'touch /mnt/s0/Data/{session}/sort_me.flag' - result, error = _run_command(cmd) - _logger.info(f'{session}: sort_me.flag created') - - # Remove files from s3 - cmd = f'aws s3 rm --recursive s3://ibl-ks2-storage/{session}' - result, error = _run_command(cmd) - if not result: - _logger.error(f'{session}: Could not remove data from s3 \n' - f'Error: {error}') - return - else: - _logger.info(f'{session}: Data removed from s3') - - return - - -def job_run_ks2(): - - # Look for flag files in /mnt/s0/Data and sort them in order of date they were created - flag_files = list(Path(root_path).glob('**/sort_me.flag')) - 
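The rename loop in job_transfer_ks2 above relies on datasets copied from FlatIron carrying a dotted UUID before the extension; a short sketch with a fabricated file name (illustrative only) shows why a stem suffix of length 37 ('.' plus a 36-character UUID) identifies them:

from pathlib import Path

name = Path('spikes.times.4a39e838-2ae0-4cc8-9b2b-5a02b9a4c9a4.npy')  # fabricated UUID
uuid_part = Path(name.stem).suffix        # '.' + 36-character UUID -> length 37
assert len(uuid_part) == 37
clean = Path(name.parent, Path(name.stem).stem + name.suffix)
# clean -> 'spikes.times.npy', the plain ALF dataset name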
flag_files.sort(key=os.path.getmtime) - - # Start with the oldest flag - session_path = flag_files[0].parent - session = str(PosixPath(*session_path.parts[4:])) - flag_files[0].unlink() - - # Instantiate one - one = ONE(cache_rest=None) - - # sync the probes - status, sync_files = sync_probes.sync(session_path) - - if not status: - _logger.error(f'{session}: Could not sync probes') - return - else: - _logger.info(f'{session}: Probes successfully synced') - - # run ks2 - task = ephys.SpikeSorting(session_path, one=one) - status = task.run() - - if status != 0: - _logger.error(f'{session}: Could not run ks2') - return - else: - _logger.info(f'{session}: ks2 successfully completed') - - # Run the cell qc - # qc_file = [] - - # Register and upload files to FTP Patcher - outfiles = task.outputs - ftp_patcher = FTPPatcher(one=one) - ftp_patcher.create_dataset(path=outfiles, created_by=one._par.ALYX_LOGIN) - - # Remove everything apart from alf folder and spike sorter folder - # Don't do this for now unitl we are sure it works for 3A and 3B!! - # cmd = f'rm -r {session_path}/raw_ephys_data rm -r {session_path}/raw_behavior_data' - # result, error = _run_command(cmd) diff --git a/ibllib/pipes/sdsc_tasks.py b/ibllib/pipes/sdsc_tasks.py deleted file mode 100644 index efd3aa058..000000000 --- a/ibllib/pipes/sdsc_tasks.py +++ /dev/null @@ -1,43 +0,0 @@ -import numpy as np - -import spikeglx -from ibllib.ephys.sync_probes import apply_sync -from ibllib.pipes.tasks import Task - - -class RegisterSpikeSortingSDSC(Task): - - @property - def signature(self): - signature = { - 'input_files': [('*sync.npy', f'raw_ephys_data/{self.pname}', False), - ('*ap.meta', f'raw_ephys_data/{self.pname}', False)], - 'output_files': [] - } - return signature - - def __init__(self, session_path, pname=None, revision_label='#test#', **kwargs): - super().__init__(session_path, **kwargs) - - self.pname = pname - self.revision_label = revision_label - - def _run(self): - - out_path = self.session_path.joinpath('alf', self.pname, 'pykilosort', self.revision_label) - - def _fs(meta_file): - # gets sampling rate from data - md = spikeglx.read_meta_data(meta_file) - return spikeglx._get_fs_from_meta(md) - - sync_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*sync.npy')) - meta_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*ap.meta')) - - st_file = out_path.joinpath('spikes.times.npy') - spike_samples = np.load(out_path.joinpath('spikes.samples.npy')) - interp_times = apply_sync(sync_file, spike_samples / _fs(meta_file), forward=True) - np.save(st_file, interp_times) - - out = list(self.session_path.joinpath('alf', self.pname, 'pykilosort', self.revision_label).glob('*')) - return out diff --git a/ibllib/qc/qcplots.py b/ibllib/qc/qcplots.py deleted file mode 100644 index 4af2b294c..000000000 --- a/ibllib/qc/qcplots.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Plots for trial QC - -Example: - one = ONE() - # Load data - eid = 'c8ef527b-6f7f-4f08-8b99-5aeb9d2b3740 - # Run QC - qc = TaskQC(eid, one=one) - plot_results(qc) - plt.show() - -""" -from collections import Counter -from collections.abc import Sized -from pathlib import Path -from datetime import datetime - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns - -from ibllib.qc.task_metrics import TaskQC - - -def plot_results(qc_obj, save_path=None): - if not isinstance(qc_obj, TaskQC): - raise ValueError('Input must be TaskQC object') - - if not qc_obj.passed: - qc_obj.compute() - outcome, results, 
outcomes = qc_obj.compute_session_status() - - # Sort checks by outcome and print - map = {k: [] for k in set(outcomes.values())} - for k, v in outcomes.items(): - map[v].append(k[6:]) - for k, v in map.items(): - print(f'The following checks were labelled {k}:') - print('\n'.join(v), '\n') - - # Collect some session details - n_trials = qc_obj.extractor.data['intervals'].shape[0] - det = qc_obj.one.get_details(qc_obj.eid) - ref = f"{datetime.fromisoformat(det['start_time']).date()}_{det['number']:d}_{det['subject']}" - title = ref + (' (Bpod data only)' if qc_obj.extractor.bpod_only else '') - - # Sort into each category - counts = Counter(outcomes.values()) - plt.bar(range(len(counts)), counts.values(), align='center', tick_label=list(counts.keys())) - plt.gcf().suptitle(title) - plt.ylabel('# QC checks') - plt.xlabel('outcome') - - a4_dims = (11.7, 8.27) - fig, (ax0, ax1) = plt.subplots(2, 1, figsize=a4_dims, constrained_layout=True) - fig.suptitle(title) - - # Plot failed trial level metrics - def get_trial_level_failed(d): - new_dict = {k[6:]: v for k, v in d.items() - if outcomes[k] == 'FAIL' and isinstance(v, Sized) and len(v) == n_trials} - return pd.DataFrame.from_dict(new_dict) - sns.boxplot(data=get_trial_level_failed(qc_obj.metrics), orient='h', ax=ax0) - ax0.set_yticklabels(ax0.get_yticklabels(), rotation=30, fontsize=8) - ax0.set(xscale='symlog', title='Metrics (failed)', xlabel='metric values (units vary)') - - # Plot failed trial level metrics - sns.barplot(data=get_trial_level_failed(qc_obj.passed), orient='h', ax=ax1) - ax1.set_yticklabels(ax1.get_yticklabels(), rotation=30, fontsize=8) - ax1.set(title='Counts', xlabel='proportion of trials that passed') - - if save_path is not None: - save_path = Path(save_path) - - if save_path.is_dir() and not save_path.exists(): - print(f"Folder {save_path} does not exist, not saving...") - elif save_path.is_dir(): - fig.savefig(save_path.joinpath(f"{ref}_QC.png")) - else: - fig.savefig(save_path)
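As an aside on the get_trial_level_failed helper above: only checks that failed and return one value per trial end up in the DataFrame that feeds the box and bar plots. A self-contained sketch (editorial illustration; check names and values are fabricated, three trials assumed):

from collections.abc import Sized
import numpy as np
import pandas as pd

n_trials = 3
outcomes = {'_task_stimOn_delays': 'FAIL', '_task_iti_delays': 'PASS'}
metrics = {'_task_stimOn_delays': np.array([0.01, 0.20, 0.05]),
           '_task_iti_delays': np.array([0.50, 0.49, 0.51])}
failed = {k[6:]: v for k, v in metrics.items()
          if outcomes[k] == 'FAIL' and isinstance(v, Sized) and len(v) == n_trials}
df = pd.DataFrame.from_dict(failed)   # single column 'stimOn_delays', one row per trial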