diff --git a/brainbox/behavior/training.py b/brainbox/behavior/training.py index 6959c3bae..a8f2f383d 100644 --- a/brainbox/behavior/training.py +++ b/brainbox/behavior/training.py @@ -157,7 +157,7 @@ def get_subject_training_status(subj, date=None, details=True, one=None): if not trials: return sess_dates = list(trials.keys()) - status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay) + status, info, _ = get_training_status(trials, task_protocol, ephys_sess, n_delay) if details: if np.any(info.get('psych')): @@ -265,13 +265,13 @@ def get_sessions(subj, date=None, one=None): if not np.any(np.array(task_protocol) == 'training'): ephys_sess = one.alyx.rest('sessions', 'list', subject=subj, date_range=[sess_dates[-1], sess_dates[0]], - django='json__PYBPOD_BOARD__icontains,ephys') + django='location__name__icontains,ephys') if len(ephys_sess) > 0: ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess] n_delay = len(one.alyx.rest('sessions', 'list', subject=subj, date_range=[sess_dates[-1], sess_dates[0]], - django='json__SESSION_START_DELAY_SEC__gte,900')) + django='json__SESSION_DELAY_START__gte,900')) else: ephys_sess_dates = [] n_delay = 0 @@ -313,23 +313,32 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): info = Bunch() trials_all = concatenate_trials(trials) + info.session_dates = list(trials.keys()) + info.protocols = [p for p in task_protocol] # Case when all sessions are trainingChoiceWorld if np.all(np.array(task_protocol) == 'training'): - signed_contrast = get_signed_contrast(trials_all) + signed_contrast = np.unique(get_signed_contrast(trials_all)) (info.perf_easy, info.n_trials, info.psych, info.rt) = compute_training_info(trials, trials_all) - if not np.any(signed_contrast == 0): - status = 'in training' + + pass_criteria, criteria = criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt, + signed_contrast) + if pass_criteria: + failed_criteria = Bunch() + failed_criteria['NBiased'] = {'val': info.protocols, 'pass': False} + failed_criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False} + status = 'trained 1b' else: - if criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt): - status = 'trained 1b' - elif criterion_1a(info.psych, info.n_trials, info.perf_easy): + failed_criteria = criteria + pass_criteria, criteria = criterion_1a(info.psych, info.n_trials, info.perf_easy, signed_contrast) + if pass_criteria: status = 'trained 1a' else: + failed_criteria = criteria status = 'in training' - return status, info + return status, info, failed_criteria # Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion if ~np.all(np.array(task_protocol) == 'training') and \ @@ -338,7 +347,11 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): (info.perf_easy, info.n_trials, info.psych, info.rt) = compute_training_info(trials, trials_all) - return status, info + criteria = Bunch() + criteria['NBiased'] = {'val': info.protocols, 'pass': False} + criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': False} + + return status, info, criteria # Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessions if not np.any(np.array(task_protocol) == 'training'): @@ -346,37 +359,40 @@ def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): (info.perf_easy, info.n_trials, info.psych_20, info.psych_80, info.rt) = compute_bias_info(trials, trials_all) - # We are still on training rig and so all sessions should be biased - 
if len(ephys_sess_dates) == 0:
-            assert np.all(np.array(task_protocol) == 'biased')
-            if criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
-                               info.rt):
-                status = 'ready4ephysrig'
-            else:
-                status = 'trained 1b'
-        elif len(ephys_sess_dates) < 3:
+        n_ephys = len(ephys_sess_dates)
+        info.n_ephys = n_ephys
+        info.n_delay = n_delay
+
+        # Criterion recording
+        pass_criteria, criteria = criteria_recording(n_ephys, n_delay, info.psych_20, info.psych_80, info.n_trials,
+                                                     info.perf_easy, info.rt)
+        if pass_criteria:
+            # Here the criteria all pass; there are no further criteria to meet, so we return them as they are
+            failed_criteria = criteria
+            status = 'ready4recording'
+        else:
+            failed_criteria = criteria
             assert all(date in trials for date in ephys_sess_dates)
             perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in
                                         ephys_sess_dates])
             n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates])
-            if criterion_delay(n_ephys_trials, perf_ephys_easy):
-                status = 'ready4delay'
-            else:
-                status = 'ready4ephysrig'
-
-        elif len(ephys_sess_dates) >= 3:
-            if n_delay > 0 and \
-                    criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
-                                    info.rt):
-                status = 'ready4recording'
-            elif criterion_delay(info.n_trials, info.perf_easy):
+            pass_criteria, criteria = criterion_delay(n_ephys, n_ephys_trials, perf_ephys_easy)
+
+            if pass_criteria:
                 status = 'ready4delay'
             else:
-                status = 'ready4ephysrig'
+                failed_criteria = criteria
+                pass_criteria, criteria = criterion_ephys(info.psych_20, info.psych_80, info.n_trials,
+                                                          info.perf_easy, info.rt)
+                if pass_criteria:
+                    status = 'ready4ephysrig'
+                else:
+                    failed_criteria = criteria
+                    status = 'trained 1b'
 
-        return status, info
+        return status, info, failed_criteria
 
 
 def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
@@ -814,7 +830,7 @@ def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='re
     return reaction_time, contrasts, n_contrasts,
 
 
-def criterion_1a(psych, n_trials, perf_easy):
+def criterion_1a(psych, n_trials, perf_easy, signed_contrast):
     """
     Returns bool indicating whether criteria for status 'trained_1a' are met.
 
@@ -825,6 +841,7 @@
     - Lapse rate on both sides is less than 0.2
     - The total number of trials is greater than 200 for each session
     - Performance on easy contrasts > 80% for all sessions
+    - Zero contrast trials must be present
 
     Parameters
     ----------
@@ -835,11 +852,15 @@
         The number for trials for each session.
     perf_easy : numpy.array of float
         The proportion of correct high contrast trials for each session.
+    signed_contrast : numpy.array
+        Unique array of signed contrasts displayed to the subject
 
     Returns
     -------
     bool
         True if the criteria are met for 'trained_1a'.
+    Bunch
+        Bunch containing breakdown of the passing/failing criteria
 
     Notes
     -----
     The parameter thresholds chosen here were originally determined by averaging the parameter fits
     for a number of sessions determined to be of 'good' performance by an experimenter.
     """
-    criterion = (abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2 and
-                 np.all(n_trials > 200) and np.all(perf_easy > 0.8))
-    return criterion
+    criteria = Bunch()
+    criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)}
+    criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.2}
+    criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.2}
+    criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 16}
+    criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 19}
+    criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 200)}
+    criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.8)}
+
+    passing = np.all([v['pass'] for k, v in criteria.items()])
+    criteria['Criteria'] = {'val': 'trained_1a', 'pass': passing}
 
-def criterion_1b(psych, n_trials, perf_easy, rt):
+    return passing, criteria
+
+
+def criterion_1b(psych, n_trials, perf_easy, rt, signed_contrast):
     """
     Returns bool indicating whether criteria for trained_1b are met.
 
@@ -864,6 +896,7 @@
     - The total number of trials is greater than 400 for each session
     - Performance on easy contrasts > 90% for all sessions
     - The median response time across all zero contrast trials is less than 2 seconds
+    - Zero contrast trials must be present
 
     Parameters
     ----------
@@ -876,11 +909,15 @@
         The proportion of correct high contrast trials for each session.
     rt : float
         The median response time for zero contrast trials.
+    signed_contrast : numpy.array
+        Unique array of signed contrasts displayed to the subject
 
     Returns
     -------
     bool
         True if the criteria are met for 'trained_1b'.
+    Bunch
+        Bunch containing breakdown of the passing/failing criteria
 
     Notes
     -----
@@ -890,17 +927,27 @@
     regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the
     slope of the psychometric curve may be slightly less steep than 1a.
     """
-    criterion = (abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1 and
-                 np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2)
-    return criterion
+
+    criteria = Bunch()
+    criteria['Zero_contrast'] = {'val': signed_contrast, 'pass': np.any(signed_contrast == 0)}
+    criteria['LapseLow_50'] = {'val': psych[2], 'pass': psych[2] < 0.1}
+    criteria['LapseHigh_50'] = {'val': psych[3], 'pass': psych[3] < 0.1}
+    criteria['Bias'] = {'val': psych[0], 'pass': abs(psych[0]) < 10}
+    criteria['Threshold'] = {'val': psych[1], 'pass': psych[1] < 20}
+    criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)}
+    criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)}
+    criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2}
+
+    passing = np.all([v['pass'] for k, v in criteria.items()])
+
+    criteria['Criteria'] = {'val': 'trained_1b', 'pass': passing}
+
+    return passing, criteria
 
 
 def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
     """
-    Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met.
-
-    NB: The difference between these two is whether the sessions were acquired ot a recording rig
-    with a delay before the first trial. Neither of these two things are tested here.
+    Returns bool indicating whether criteria for ready4ephysrig are met.
 
     Criteria
     --------
@@ -929,21 +976,34 @@
 
     Returns
     -------
     bool
-        True if subject passes the ready4ephysrig or ready4recording criteria.
+        True if subject passes the ready4ephysrig criteria.
+    Bunch
+        Bunch containing breakdown of the passing/failing criteria
     """
+    criteria = Bunch()
+    criteria['LapseLow_80'] = {'val': psych_80[2], 'pass': psych_80[2] < 0.1}
+    criteria['LapseHigh_80'] = {'val': psych_80[3], 'pass': psych_80[3] < 0.1}
+    criteria['LapseLow_20'] = {'val': psych_20[2], 'pass': psych_20[2] < 0.1}
+    criteria['LapseHigh_20'] = {'val': psych_20[3], 'pass': psych_20[3] < 0.1}
+    criteria['Bias_shift'] = {'val': psych_80[0] - psych_20[0], 'pass': psych_80[0] - psych_20[0] > 5}
+    criteria['N_trials'] = {'val': n_trials, 'pass': np.all(n_trials > 400)}
+    criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.all(perf_easy > 0.9)}
+    criteria['Reaction_time'] = {'val': rt, 'pass': rt < 2}
 
-    criterion = (np.all(np.r_[psych_20[2:4], psych_80[2:4]] < 0.1) and  # lapse
-                 psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and  # bias shift and n trials
-                 np.all(perf_easy > 0.9) and rt < 2)  # overall performance and response times
-    return criterion
+    passing = np.all([v['pass'] for k, v in criteria.items()])
+    criteria['Criteria'] = {'val': 'ready4ephysrig', 'pass': passing}
 
-def criterion_delay(n_trials, perf_easy):
+    return passing, criteria
+
+
+def criterion_delay(n_ephys, n_trials, perf_easy):
     """
     Returns bool indicating whether criteria for 'ready4delay' is met.
 
     Criteria
     --------
+    - At least one session on an ephys rig
     - Total number of trials for any of the sessions is greater than 400
     - Performance on easy contrasts is greater than 90% for any of the sessions
 
@@ -959,9 +1019,73 @@
 
     Returns
     -------
     bool
         True if subject passes the 'ready4delay' criteria.
+    Bunch
+        Bunch containing breakdown of the passing/failing criteria
+    """
+
+    criteria = Bunch()
+    criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys > 0}
+    criteria['N_trials'] = {'val': n_trials, 'pass': np.any(n_trials > 400)}
+    criteria['Perf_easy'] = {'val': perf_easy, 'pass': np.any(perf_easy > 0.9)}
+
+    passing = np.all([v['pass'] for k, v in criteria.items()])
+
+    criteria['Criteria'] = {'val': 'ready4delay', 'pass': passing}
+
+    return passing, criteria
+
+
+def criteria_recording(n_ephys, delay, psych_20, psych_80, n_trials, perf_easy, rt):
     """
-    criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9)
-    return criterion
+    Returns bool indicating whether criteria for ready4recording are met.
+
+    Criteria
+    --------
+    - At least 3 ephys sessions
+    - Delay on any session > 0
+    - Lapse on both sides < 0.1 for both bias blocks
+    - Bias shift between blocks > 5
+    - Total number of trials > 400 for all sessions
+    - Performance on easy contrasts > 90% for all sessions
+    - Median response time for zero contrast stimuli < 2 seconds
+
+    Parameters
+    ----------
+    n_ephys : int
+        The number of sessions acquired on an ephys rig.
+    delay : int
+        The number of sessions with a delay of 15 minutes or more before the first trial.
+    psych_20 : numpy.array
+        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
+        Parameters are bias, threshold, lapse high, lapse low.
+    psych_80 : numpy.array
+        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
+        Parameters are bias, threshold, lapse high, lapse low.
+    n_trials : numpy.array
+        The number of trials for each session (typically three consecutive sessions).
+    perf_easy : numpy.array
+        The proportion of correct high contrast trials for each session (typically three
+        consecutive sessions).
+    rt : float
+        The median response time for zero contrast trials.
+
+    Returns
+    -------
+    bool
+        True if subject passes the ready4recording criteria.
+    Bunch
+        Bunch containing breakdown of the passing/failing criteria
+    """
+
+    _, criteria = criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt)
+    criteria['N_ephys'] = {'val': n_ephys, 'pass': n_ephys >= 3}
+    criteria['N_delay'] = {'val': delay, 'pass': delay > 0}
+
+    passing = np.all([v['pass'] for k, v in criteria.items()])
+
+    criteria['Criteria'] = {'val': 'ready4recording', 'pass': passing}
+
+    return passing, criteria
 
 
 def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs):
diff --git a/brainbox/tests/test_behavior.py b/brainbox/tests/test_behavior.py
index 8d02d185a..493234937 100644
--- a/brainbox/tests/test_behavior.py
+++ b/brainbox/tests/test_behavior.py
@@ -177,58 +177,65 @@ def test_in_training(self):
         trials, task_protocol = self._get_trials(
             sess_dates=['2020-08-25', '2020-08-24', '2020-08-21'])
         assert (np.all(np.array(task_protocol) == 'training'))
-        status, info = train.get_training_status(
+        status, info, crit = train.get_training_status(
             trials, task_protocol, ephys_sess_dates=[], n_delay=0)
         assert (status == 'in training')
+        assert (crit['Criteria']['val'] == 'trained_1a')
 
     def test_trained_1a(self):
         trials, task_protocol = self._get_trials(
             sess_dates=['2020-08-26', '2020-08-25', '2020-08-24'])
         assert (np.all(np.array(task_protocol) == 'training'))
-        status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
-                                                 n_delay=0)
+        status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
+                                                       n_delay=0)
         assert (status == 'trained 1a')
+        assert (crit['Criteria']['val'] == 'trained_1b')
 
     def test_trained_1b(self):
         trials, task_protocol = self._get_trials(
             sess_dates=['2020-08-27', '2020-08-26', '2020-08-25'])
         assert (np.all(np.array(task_protocol) == 'training'))
-        status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
-                                                 n_delay=0)
+        status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
+                                                       n_delay=0)
         self.assertEqual(status, 'trained 1b')
+        assert (crit['Criteria']['val'] == 'ready4ephysrig')
 
     def test_training_to_bias(self):
         trials, task_protocol = self._get_trials(
             sess_dates=['2020-08-31', '2020-08-28', '2020-08-27'])
         assert (~np.all(np.array(task_protocol) == 'training') and
                 np.any(np.array(task_protocol) == 'training'))
-        status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
-                                                 n_delay=0)
+        status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
+                                                       n_delay=0)
         assert (status == 'trained 1b')
+        assert (crit['Criteria']['val'] == 'ready4ephysrig')
 
     def test_ready4ephys(self):
         sess_dates = ['2020-09-01', '2020-08-31', '2020-08-28']
         trials, task_protocol = self._get_trials(sess_dates=sess_dates)
         assert (np.all(np.array(task_protocol) == 'biased'))
-        status, info = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
-                                                 n_delay=0)
+        status, info, crit = train.get_training_status(trials, task_protocol, ephys_sess_dates=[],
+                                                       n_delay=0)
         assert (status == 'ready4ephysrig')
+        assert (crit['Criteria']['val'] == 'ready4delay')
 
     def test_ready4delay(self):
         sess_dates = ['2020-09-03', '2020-09-02', '2020-08-31']
         trials, task_protocol = self._get_trials(sess_dates=sess_dates)
         assert (np.all(np.array(task_protocol) == 'biased'))
-        status, info = train.get_training_status(trials, task_protocol,
-                                                 ephys_sess_dates=['2020-09-03'], n_delay=0)
+        status, info, crit = train.get_training_status(trials, task_protocol,
+                                                       ephys_sess_dates=['2020-09-03'], n_delay=0)
         assert (status == 'ready4delay')
+        assert (crit['Criteria']['val'] == 'ready4recording')
 
     def test_ready4recording(self):
         sess_dates = ['2020-09-01', '2020-08-31', '2020-08-28']
         trials, task_protocol = self._get_trials(sess_dates=sess_dates)
         assert (np.all(np.array(task_protocol) == 'biased'))
-        status, info = train.get_training_status(trials, task_protocol,
-                                                 ephys_sess_dates=sess_dates, n_delay=1)
+        status, info, crit = train.get_training_status(trials, task_protocol,
+                                                       ephys_sess_dates=sess_dates, n_delay=1)
         assert (status == 'ready4recording')
+        assert (crit['Criteria']['val'] == 'ready4recording')
 
     def test_query_criterion(self):
        """Test for brainbox.behavior.training.query_criterion function."""
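The tests above pin down the new contract: `get_training_status` now returns a third value holding a per-criterion breakdown rather than a bare status string. A minimal sketch of how a consumer can use it, assuming `trials` is a date-keyed dict of ALF trials objects and `task_protocol` the matching list of protocol names:

```python
from brainbox.behavior import training

# trials: dict of ALF trials objects keyed by session date (assumed defined)
# task_protocol: matching list of protocol names ('training', 'biased' or 'ephys')
status, info, criteria = training.get_training_status(
    trials, task_protocol, ephys_sess_dates=[], n_delay=0)

# the 'Criteria' entry names the criterion that was evaluated, e.g. 'trained_1a'
print(status, criteria['Criteria']['val'])
# every other entry is a {'val': ..., 'pass': ...} breakdown of that criterion
for name, crit in criteria.items():
    if name != 'Criteria':
        print(name, crit['val'], crit['pass'])
```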
diff --git a/examples/exploring_data/data_download.ipynb b/examples/exploring_data/data_download.ipynb
index bfaca800f..48d706d2d 100644
--- a/examples/exploring_data/data_download.ipynb
+++ b/examples/exploring_data/data_download.ipynb
@@ -37,10 +37,10 @@
    "source": [
     "## Installation\n",
     "### Environment\n",
-    "To use IBL data you will need a python environment with python > 3.8. To create a new environment from scratch you can install [anaconda](https://www.anaconda.com/products/distribution#download-section) and follow the instructions below to create a new python environment (more information can also be found [here](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html))\n",
+    "To use IBL data you will need a Python environment with Python 3.10 or later (Python 3.12 is recommended). To create a new environment from scratch you can install [anaconda](https://www.anaconda.com/products/distribution#download-section) and follow the instructions below to create a new python environment (more information can also be found [here](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html))\n",
     "\n",
     "```\n",
-    "conda create --name ibl python=3.11\n",
+    "conda create --name ibl python=3.12\n",
     "```\n",
     "Make sure to always activate this environment before installing or working with the IBL data\n",
     "```\n",
@@ -138,9 +138,33 @@
    "outputs": [],
    "source": [
     "# Each session is represented by a unique experiment id (eID)\n",
     "print(sessions[0])"
    ]
   },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": [
+    "### Find recordings of a specific brain region\n",
+    "If we are interested in a given brain region, we can use the `search_insertions` method to find all recordings associated with that region. For example, to find all recordings associated with the **Rhomboid Nucleus (RH)** region of the thalamus:"
+   ]
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": [
+    "# this is the query that yields the few recordings for the Rhomboid Nucleus (RH) region\n",
+    "insertions_rh = one.search_insertions(atlas_acronym='RH', datasets='spikes.times.npy', project='brainwide')\n",
+    "\n",
+    "# if we want to extend the search to all thalamic regions, we can do the following\n",
+    "insertions_th = one.search_insertions(atlas_acronym='TH', datasets='spikes.times.npy', project='brainwide')\n",
+    "\n",
+    "# the Allen brain region parcellation is hierarchical, so searching for the thalamus (TH) also returns insertions in its child regions, such as the Rhomboid Nucleus (RH)\n",
+    "assert set(insertions_rh).issubset(set(insertions_th))\n"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
  {
   "cell_type": "markdown",
   "metadata": {},
@@ -402,9 +426,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.16"
+   "version": "3.11.9"
   }
  },
 "nbformat": 4,
- "nbformat_minor": 1
+ "nbformat_minor": 4
 }
diff --git a/examples/loading_data/loading_spike_waveforms.ipynb b/examples/loading_data/loading_spike_waveforms.ipynb
deleted file mode 100644
index 44b659980..000000000
--- a/examples/loading_data/loading_spike_waveforms.ipynb
+++ /dev/null
@@ -1,184 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "id": "f73e02ee",
-   "metadata": {},
-   "source": [
-    "# Loading Spike Waveforms"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "ea70eb4a",
-   "metadata": {
-    "nbsphinx": "hidden"
-   },
-   "outputs": [],
-   "source": [
-    "# Turn off logging and disable tqdm this is a hidden cell on docs page\n",
-    "import logging\n",
-    "import os\n",
-    "\n",
-    "logger = logging.getLogger('ibllib')\n",
-    "logger.setLevel(logging.CRITICAL)\n",
-    "\n",
-    "os.environ[\"TQDM_DISABLE\"] = \"1\""
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "64cec921",
-   "metadata": {},
-   "source": [
-    "Sample of spike waveforms extracted during spike sorting"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "dca47f09",
-   "metadata": {},
-   "source": [
-    "## Relevant Alf objects\n",
-    "* \\_phy_spikes_subset"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "eb34d848",
-   "metadata": {},
-   "source": [
-    "## Loading"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "c5d32232",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%%capture\n",
-    "from one.api import ONE\n",
-    "from brainbox.io.one import SpikeSortingLoader\n",
-    "from iblatlas.atlas import AllenAtlas\n",
-    "\n",
-    "one = ONE()\n",
-    "ba = AllenAtlas()\n",
-    "pid = 'da8dfec1-d265-44e8-84ce-6ae9c109b8bd' \n",
-    "\n",
-    "# Load in the spikesorting\n",
-    "sl = SpikeSortingLoader(pid=pid, one=one, atlas=ba)\n",
-    "spikes, clusters, channels = sl.load_spike_sorting()\n",
-    "clusters = sl.merge_clusters(spikes, clusters, channels)\n",
-    "\n",
-    "# Load the spike waveforms\n",
-    "spike_wfs = one.load_object(sl.eid, '_phy_spikes_subset', collection=sl.collection)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "327a23e7",
-   "metadata": {},
-   "source": [
-    "## More details\n",
-    "* [Description of datasets](https://docs.google.com/document/d/1OqIqqakPakHXRAwceYLwFY9gOrm8_P62XIfCTnHwstg/edit#heading=h.vcop4lz26gs9)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "257fb8b8",
-   "metadata": {},
-   "source": [
-    "## Useful modules\n",
-    "* COMING SOON"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "157bf219",
-   "metadata": {},
-   "source": [
-    "## Exploring sample waveforms"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "a617f8fb",
-   "metadata": 
{}, - "source": [ - "### Example 1: Finding the cluster ID for each sample waveform" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ac805b6", - "metadata": {}, - "outputs": [], - "source": [ - "# Find the cluster id for each sample waveform\n", - "wf_clusterIDs = spikes['clusters'][spike_wfs['spikes']]" - ] - }, - { - "cell_type": "markdown", - "id": "baf9eb11", - "metadata": {}, - "source": [ - "### Example 2: Compute average waveform for cluster" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3d8a729c", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "\n", - "# define cluster of interest\n", - "clustID = 2\n", - "\n", - "# Find waveforms for this cluster\n", - "wf_idx = np.where(wf_clusterIDs == clustID)[0]\n", - "wfs = spike_wfs['waveforms'][wf_idx, :, :]\n", - "\n", - "# Compute average waveform on channel with max signal (chn_index 0)\n", - "wf_avg_chn_max = np.mean(wfs[:, :, 0], axis=0)" - ] - }, - { - "cell_type": "markdown", - "id": "a20b24ea", - "metadata": {}, - "source": [ - "## Other relevant examples\n", - "* COMING SOON" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/loading_data/loading_spikesorting_data.ipynb b/examples/loading_data/loading_spikesorting_data.ipynb index f711414a1..db568215b 100644 --- a/examples/loading_data/loading_spikesorting_data.ipynb +++ b/examples/loading_data/loading_spikesorting_data.ipynb @@ -43,7 +43,8 @@ "## Relevant Alf objects\n", "* channels\n", "* clusters\n", - "* spikes" + "* spikes\n", + "* waveforms" ] }, { @@ -74,9 +75,10 @@ "outputs": [], "source": [ "pid = 'da8dfec1-d265-44e8-84ce-6ae9c109b8bd' \n", - "sl = SpikeSortingLoader(pid=pid, one=one)\n", - "spikes, clusters, channels = sl.load_spike_sorting()\n", - "clusters = sl.merge_clusters(spikes, clusters, channels)" + "ssl = SpikeSortingLoader(pid=pid, one=one)\n", + "spikes, clusters, channels = ssl.load_spike_sorting()\n", + "clusters = ssl.merge_clusters(spikes, clusters, channels)\n", + "waveforms = ssl.load_spike_sorting_object('waveforms') # loads in the template waveforms" ] }, { diff --git a/ibllib/__init__.py b/ibllib/__init__.py index 3525165e1..c1857164d 100644 --- a/ibllib/__init__.py +++ b/ibllib/__init__.py @@ -2,7 +2,7 @@ import logging import warnings -__version__ = '3.1.0' +__version__ = '3.2.0' warnings.filterwarnings('always', category=DeprecationWarning, module='ibllib') # if this becomes a full-blown library we should let the logging configuration to the discretion of the dev diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 2980eb7bf..4da3d6bd8 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -69,7 +69,7 @@ """int: The number of encoder pulses per channel for one complete rotation.""" BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150 -"""int: Throws an error if Bpod to FPGA clock drift is higher than this value.""" +"""int: Logs a warning if Bpod to FPGA clock drift is higher than this value.""" CHMAPS = {'3A': {'ap': @@ -545,17 +545,23 @@ def get_main_probe_sync(session_path, bin_exists=False): return sync, 
sync_chmap
 
 
-def get_protocol_period(session_path, protocol_number, bpod_sync):
+def get_protocol_period(session_path, protocol_number, bpod_sync, exclude_empty_periods=True):
     """
+    Return the start and end time of the given protocol.
+
+    Note that the start time is the start of the spacer pulses and the end time is either None
+    if the protocol is the final one, or the start of the next spacer.
 
     Parameters
     ----------
     session_path : str, pathlib.Path
         The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
     protocol_number : int
-        The order that the protocol was run in.
+        The order that the protocol was run in, counted from 0.
     bpod_sync : dict
         The sync times and polarities for Bpod BNC1.
+    exclude_empty_periods : bool
+        When true, spacers are ignored if no Bpod pulses are detected in the period that follows them.
 
     Returns
     -------
@@ -565,7 +571,14 @@
         The time of the next detected spacer or None if this is the last protocol run.
     """
     # The spacers are TTLs generated by Bpod at the start of each protocol
-    spacer_times = Spacer().find_spacers_from_fronts(bpod_sync)
+    sp = Spacer()
+    spacer_times = sp.find_spacers_from_fronts(bpod_sync)
+    if exclude_empty_periods:
+        # Drop dud protocol spacers (those without any bpod pulses after the spacer)
+        spacer_length = len(sp.generate_template(fs=1000)) / 1000
+        periods = np.c_[spacer_times + spacer_length, np.r_[spacer_times[1:], np.inf]]
+        valid = [np.any((bpod_sync['times'] > pp[0]) & (bpod_sync['times'] < pp[1])) for pp in periods]
+        spacer_times = spacer_times[valid]
     # Ensure that the number of detected spacers matched the number of expected tasks
     if acquisition_description := session_params.read_params(session_path):
         n_tasks = len(acquisition_description.get('tasks', []))
diff --git a/ibllib/oneibl/registration.py b/ibllib/oneibl/registration.py
index c5dc71553..51a4af84f 100644
--- a/ibllib/oneibl/registration.py
+++ b/ibllib/oneibl/registration.py
@@ -293,6 +293,11 @@ def register_session(self, ses_path, file_list=True, projects=None, procedures=N
             poo_counts = [md.get('POOP_COUNT') for md in settings if md.get('POOP_COUNT') is not None]
             if poo_counts:
                 json_field['POOP_COUNT'] = int(sum(poo_counts))
+            # Get the session start delay if available, needed for the training status
+            session_delay = [md.get('SESSION_DELAY_START') for md in settings
+                             if md.get('SESSION_DELAY_START') is not None]
+            if session_delay:
+                json_field['SESSION_DELAY_START'] = int(sum(session_delay))
 
         if not len(session):  # Create session and weighings
             ses_ = {'subject': subject['nickname'],
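With the delay stored in the session JSON at registration, delayed recording sessions become queryable directly from Alyx. A minimal sketch, assuming an authenticated ONE instance ('SWC_043' is a placeholder subject name); this is the same `json__SESSION_DELAY_START__gte,900` filter that `get_sessions` in `brainbox.behavior.training` now uses to count delayed sessions:

```python
from one.api import ONE

one = ONE()
# sessions registered with the new key carry SESSION_DELAY_START (seconds) in their JSON field
n_delay = len(one.alyx.rest('sessions', 'list', subject='SWC_043',
                            django='json__SESSION_DELAY_START__gte,900'))
```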
diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py
index 3f519a10c..be75cf0d6 100644
--- a/ibllib/pipes/behavior_tasks.py
+++ b/ibllib/pipes/behavior_tasks.py
@@ -20,7 +20,7 @@
 from ibllib.pipes import training_status
 from ibllib.plots.figures import BehaviourPlots
 
-_logger = logging.getLogger('ibllib')
+_logger = logging.getLogger(__name__)
 
 
 class HabituationRegisterRaw(base_tasks.RegisterRawDataTask, base_tasks.BehaviourTask):
diff --git a/ibllib/pipes/training_status.py b/ibllib/pipes/training_status.py
index fec33baaf..50d28707f 100644
--- a/ibllib/pipes/training_status.py
+++ b/ibllib/pipes/training_status.py
@@ -208,9 +208,9 @@ def load_combined_trials(sess_paths, one, force=True):
     """
     trials_dict = {}
     for sess_path in sess_paths:
-        trials = load_trials(Path(sess_path), one, force=force)
+        trials = load_trials(Path(sess_path), one, force=force, mode='warn')
         if trials is not None:
-            trials_dict[Path(sess_path).stem] = load_trials(Path(sess_path), one, force=force)
+            trials_dict[Path(sess_path).stem] = load_trials(Path(sess_path), one, force=force, mode='warn')
 
     return training.concatenate_trials(trials_dict)
 
@@ -270,7 +272,7 @@ def get_latest_training_information(sess_path, one, save=True):
     # Find the earliest date in missing dates that we need to recompute the training status for
     missing_status = find_earliest_recompute_date(df.drop_duplicates('date').reset_index(drop=True))
     for date in missing_status:
-        df = compute_training_status(df, date, one)
+        df, _, _, _ = compute_training_status(df, date, one)
 
     df_lim = df.drop_duplicates(subset='session_path', keep='first')
 
@@ -314,7 +316,7 @@ def find_earliest_recompute_date(df):
     return df[first_index:].date.values
 
 
-def compute_training_status(df, compute_date, one, force=True):
+def compute_training_status(df, compute_date, one, force=True, populate=True):
     """
     Compute the training status for compute date based on training from that session and two previous days.
 
@@ -331,11 +333,19 @@
         An instance of ONE for loading trials data.
     force : bool
         When true and if the session trials can't be found, will attempt to re-extract from disk.
+    populate : bool
+        Whether to update the training data frame with the new training status value
 
     Returns
     -------
     pandas.DataFrame
-        The input data frame with a 'training_status' column populated for `compute_date`.
+        The input data frame with a 'training_status' column populated for `compute_date` if populate=True
+    Bunch
+        Bunch containing the fit parameters information for the combined sessions
+    Bunch
+        Bunch containing the training status criteria information
+    str
+        The training status
     """
 
     # compute_date = str(alfiles.session_path_parts(session_path, as_dict=True)['date'])
@@ -378,11 +388,12 @@
         ephys_sessions.append(df_date.iloc[-1]['date'])
 
     n_status = np.max([-2, -1 * len(status)])
-    training_status, _ = training.get_training_status(trials, protocol, ephys_sessions, n_delay)
+    training_status, info, criteria = training.get_training_status(trials, protocol, ephys_sessions, n_delay)
     training_status = pass_through_training_hierachy(training_status, status[n_status])
-    df.loc[df['date'] == compute_date, 'training_status'] = training_status
+    if populate:
+        df.loc[df['date'] == compute_date, 'training_status'] = training_status
 
-    return df
+    return df, info, criteria, training_status
 
 
 def pass_through_training_hierachy(status_new, status_old):
@@ -433,12 +444,13 @@ def compute_session_duration_delay_location(sess_path, collections=None, **kwarg
         try:
             start_time, end_time = _get_session_times(sess_path, md, sess_data)
             session_duration = session_duration + int((end_time - start_time).total_seconds() / 60)
-            session_delay = session_delay + md.get('SESSION_START_DELAY_SEC', 0)
+            session_delay = session_delay + md.get('SESSION_DELAY_START',
+                                                   md.get('SESSION_START_DELAY_SEC', 0))
         except Exception:
             session_duration = session_duration + 0
             session_delay = session_delay + 0
 
-        if 'ephys' in md.get('PYBPOD_BOARD', None):
+        if 'ephys' in md.get('RIG_NAME', md.get('PYBPOD_BOARD', None)):
             session_location = 'ephys_rig'
         else:
             session_location = 'training_rig'
@@ -586,9 +598,12 @@ def get_training_info_for_session(session_paths, one, force=True):
         session_path = Path(session_path)
         protocols = []
         for c in collections:
-            prot = get_bpod_extractor_class(session_path, task_collection=c)
-            prot = prot[:-6].lower()
-            protocols.append(prot)
+            try:
+                prot = get_bpod_extractor_class(session_path, 
task_collection=c) - prot = prot[:-6].lower() - protocols.append(prot) + try: + prot = get_bpod_extractor_class(session_path, task_collection=c) + prot = prot[:-6].lower() + protocols.append(prot) + except ValueError: + continue un_protocols = np.unique(protocols) # Example, training, training, biased - training would be combined, biased not @@ -751,9 +766,54 @@ def plot_performance_easy_median_reaction_time(df, subject): return ax +def display_info(df, axs): + compute_date = df['date'].values[-1] + _, info, criteria, _ = compute_training_status(df, compute_date, None, force=False, populate=False) + + def _array_to_string(vals): + if isinstance(vals, (str, bool, int, float)): + if isinstance(vals, float): + vals = np.round(vals, 3) + return f'{vals}' + + str_vals = '' + for v in vals: + if isinstance(v, float): + v = np.round(v, 3) + str_vals += f'{v}, ' + return str_vals[:-2] + + pos = np.arange(len(criteria))[::-1] * 0.1 + for i, (k, v) in enumerate(info.items()): + str_v = _array_to_string(v) + text = axs[0].text(0, pos[i], k.capitalize(), color='k', weight='bold', fontsize=8, transform=axs[0].transAxes) + axs[0].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", + color='k', fontsize=7) + + pos = np.arange(len(criteria))[::-1] * 0.1 + crit_val = criteria.pop('Criteria') + c = 'g' if crit_val['pass'] else 'r' + str_v = _array_to_string(crit_val['val']) + text = axs[1].text(0, pos[0], 'Criteria', color='k', weight='bold', fontsize=8, transform=axs[1].transAxes) + axs[1].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", + color=c, fontsize=7) + pos = pos[1:] + + for i, (k, v) in enumerate(criteria.items()): + c = 'g' if v['pass'] else 'r' + str_v = _array_to_string(v['val']) + text = axs[1].text(0, pos[i], k.capitalize(), color='k', weight='bold', fontsize=8, transform=axs[1].transAxes) + axs[1].annotate(': ' + str_v, xycoords=text, xy=(1, 0), verticalalignment="bottom", + color=c, fontsize=7) + + axs[0].set_axis_off() + axs[1].set_axis_off() + + def plot_fit_params(df, subject): - fig, axs = plt.subplots(2, 2, figsize=(12, 6)) - axs = axs.ravel() + fig, axs = plt.subplots(2, 3, figsize=(12, 6), gridspec_kw={'width_ratios': [2, 2, 1]}) + + display_info(df, axs=[axs[0, 2], axs[1, 2]]) df = df.drop_duplicates('date').reset_index(drop=True) @@ -777,11 +837,11 @@ def plot_fit_params(df, subject): 'color': cmap[0], 'join': False} - plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False) - plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False) - plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False) - axs[0].axhline(16, linewidth=2, linestyle='--', color='k') - axs[0].axhline(-16, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[0, 0], legend=False, title=False) + plot_over_days(df, subject, y80, ax=axs[0, 0], legend=False, title=False) + plot_over_days(df, subject, y20, ax=axs[0, 0], legend=False, title=False) + axs[0, 0].axhline(16, linewidth=2, linestyle='--', color='k') + axs[0, 0].axhline(-16, linewidth=2, linestyle='--', color='k') y50['column'] = 'combined_thres_50' y50['title'] = 'Threshold' @@ -793,10 +853,10 @@ def plot_fit_params(df, subject): y20['title'] = 'Threshold' y80['lim'] = [0, 100] - plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False) - plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False) - plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False) - axs[1].axhline(19, linewidth=2, 
linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[0, 1], legend=False, title=False) + plot_over_days(df, subject, y80, ax=axs[0, 1], legend=False, title=False) + plot_over_days(df, subject, y20, ax=axs[0, 1], legend=False, title=False) + axs[0, 1].axhline(19, linewidth=2, linestyle='--', color='k') y50['column'] = 'combined_lapselow_50' y50['title'] = 'Lapse Low' @@ -808,10 +868,10 @@ def plot_fit_params(df, subject): y20['title'] = 'Lapse Low' y20['lim'] = [0, 1] - plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False) - plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False) - plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False) - axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[1, 0], legend=False, title=False) + plot_over_days(df, subject, y80, ax=axs[1, 0], legend=False, title=False) + plot_over_days(df, subject, y20, ax=axs[1, 0], legend=False, title=False) + axs[1, 0].axhline(0.2, linewidth=2, linestyle='--', color='k') y50['column'] = 'combined_lapsehigh_50' y50['title'] = 'Lapse High' @@ -823,19 +883,21 @@ def plot_fit_params(df, subject): y20['title'] = 'Lapse High' y20['lim'] = [0, 1] - plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True) - plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False) - plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False) - axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k') + plot_over_days(df, subject, y50, ax=axs[1, 1], legend=False, title=False, training_lines=True) + plot_over_days(df, subject, y80, ax=axs[1, 1], legend=False, title=False, training_lines=False) + plot_over_days(df, subject, y20, ax=axs[1, 1], legend=False, title=False, training_lines=False) + axs[1, 1].axhline(0.2, linewidth=2, linestyle='--', color='k') fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}') - lines, labels = axs[3].get_legend_handles_labels() - fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5) + lines, labels = axs[1, 1].get_legend_handles_labels() + fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), facecolor='w', fancybox=True, shadow=True, + ncol=5) legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8), Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8), Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)] - legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True) + legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, + shadow=True, facecolor='w') fig.add_artist(legend2) return axs @@ -844,7 +906,7 @@ def plot_fit_params(df, subject): def plot_psychometric_curve(df, subject, one): df = df.drop_duplicates('date').reset_index(drop=True) sess_path = Path(df.iloc[-1]["session_path"]) - trials = load_trials(sess_path, one) + trials = load_trials(sess_path, one, mode='warn') fig, ax1 = plt.subplots(figsize=(8, 6)) @@ -907,7 +969,7 @@ def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, t box.width, box.height * 0.9]) if legend: ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1), - fancybox=True, shadow=True, ncol=5) + 
fancybox=True, shadow=True, ncol=5, facecolor='white') return ax1 @@ -1010,7 +1072,7 @@ def make_plots(session_path, one, df=None, save=False, upload=False, task_collec save_name = save_path.joinpath('subj_psychometric_fit_params.png') outputs.append(save_name) - ax4[0].get_figure().savefig(save_name, bbox_inches='tight') + ax4[0, 0].get_figure().savefig(save_name, bbox_inches='tight') save_name = save_path.joinpath('subj_psychometric_curve.png') outputs.append(save_name) diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 86f0e4a9b..b5f3cd7e7 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -682,6 +682,7 @@ def check_iti_delays(data, subtract_pauses=False, iti_delay_secs=ITI_DELAY_SECS, numpy.array An array of boolean values, 1 per trial, where True means trial passes QC threshold. """ + # Initialize array the length of completed trials metric = np.full(data['intervals'].shape[0], np.nan) passed = metric.copy() pauses = (data['pause_duration'] if subtract_pauses else np.zeros_like(metric))[:-1] diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 48155b270..cae7431c2 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -1,7 +1,21 @@ """An interactive PyQT QC data frame.""" + import logging -from PyQt5 import QtCore, QtWidgets +from PyQt5 import QtWidgets +from PyQt5.QtCore import ( + Qt, + QModelIndex, + pyqtSignal, + pyqtSlot, + QCoreApplication, + QSettings, + QSize, + QPoint, +) +from PyQt5.QtGui import QPalette, QShowEvent +from PyQt5.QtWidgets import QMenu, QAction +from iblqt.core import ColoredDataFrameTableModel from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT import pandas as pd @@ -12,101 +26,17 @@ _logger = logging.getLogger(__name__) -class DataFrameModel(QtCore.QAbstractTableModel): - DtypeRole = QtCore.Qt.UserRole + 1000 - ValueRole = QtCore.Qt.UserRole + 1001 - - def __init__(self, df=pd.DataFrame(), parent=None): - super(DataFrameModel, self).__init__(parent) - self._dataframe = df - - def setDataFrame(self, dataframe): - self.beginResetModel() - self._dataframe = dataframe.copy() - self.endResetModel() - - def dataFrame(self): - return self._dataframe - - dataFrame = QtCore.pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame) - - @QtCore.pyqtSlot(int, QtCore.Qt.Orientation, result=str) - def headerData(self, section: int, orientation: QtCore.Qt.Orientation, - role: int = QtCore.Qt.DisplayRole): - if role == QtCore.Qt.DisplayRole: - if orientation == QtCore.Qt.Horizontal: - return self._dataframe.columns[section] - else: - return str(self._dataframe.index[section]) - return QtCore.QVariant() - - def rowCount(self, parent=QtCore.QModelIndex()): - if parent.isValid(): - return 0 - return len(self._dataframe.index) - - def columnCount(self, parent=QtCore.QModelIndex()): - if parent.isValid(): - return 0 - return self._dataframe.columns.size - - def data(self, index, role=QtCore.Qt.DisplayRole): - if (not index.isValid() or not (0 <= index.row() < self.rowCount() and - 0 <= index.column() < self.columnCount())): - return QtCore.QVariant() - row = self._dataframe.index[index.row()] - col = self._dataframe.columns[index.column()] - dt = self._dataframe[col].dtype - - val = self._dataframe.iloc[row][col] - if role == QtCore.Qt.DisplayRole: - return str(val) - elif role == DataFrameModel.ValueRole: - return val - if role == DataFrameModel.DtypeRole: - return 
dt - return QtCore.QVariant() - - def roleNames(self): - roles = { - QtCore.Qt.DisplayRole: b'display', - DataFrameModel.DtypeRole: b'dtype', - DataFrameModel.ValueRole: b'value' - } - return roles - - def sort(self, col, order): - """ - Sort table by given column number. - - :param col: the column number selected (between 0 and self._dataframe.columns.size) - :param order: the order to be sorted, 0 is descending; 1, ascending - :return: - """ - self.layoutAboutToBeChanged.emit() - col_name = self._dataframe.columns.values[col] - # print('sorting by ' + col_name) - self._dataframe.sort_values(by=col_name, ascending=not order, inplace=True) - self._dataframe.reset_index(inplace=True, drop=True) - self.layoutChanged.emit() - - class PlotCanvas(FigureCanvasQTAgg): - def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None): fig = Figure(figsize=(width, height), dpi=dpi) FigureCanvasQTAgg.__init__(self, fig) self.setParent(parent) - FigureCanvasQTAgg.setSizePolicy( - self, - QtWidgets.QSizePolicy.Expanding, - QtWidgets.QSizePolicy.Expanding) + FigureCanvasQTAgg.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) FigureCanvasQTAgg.updateGeometry(self) if wheel: - self.ax, self.ax2 = fig.subplots( - 2, 1, gridspec_kw={'height_ratios': [2, 1]}, sharex=True) + self.ax, self.ax2 = fig.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, sharex=True) else: self.ax = fig.add_subplot(111) self.draw() @@ -116,69 +46,210 @@ class PlotWindow(QtWidgets.QWidget): def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=None) self.canvas = PlotCanvas(wheel=wheel) - self.vbl = QtWidgets.QVBoxLayout() # Set box for plotting + self.vbl = QtWidgets.QVBoxLayout() # Set box for plotting self.vbl.addWidget(self.canvas) self.setLayout(self.vbl) self.vbl.addWidget(NavigationToolbar2QT(self.canvas, self)) class GraphWindow(QtWidgets.QWidget): + _pinnedColumns = [] + def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=parent) - vLayout = QtWidgets.QVBoxLayout(self) + + self.columnPinned = pyqtSignal(int, bool) + + # load button + self.pushButtonLoad = QtWidgets.QPushButton('Select File', self) + self.pushButtonLoad.clicked.connect(self.loadFile) + + # define table model & view + self.tableModel = ColoredDataFrameTableModel(self) + self.tableView = QtWidgets.QTableView(self) + self.tableView.setModel(self.tableModel) + self.tableView.setSortingEnabled(True) + self.tableView.horizontalHeader().setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter) + self.tableView.horizontalHeader().setSectionsMovable(True) + self.tableView.horizontalHeader().setContextMenuPolicy(Qt.CustomContextMenu) + self.tableView.horizontalHeader().customContextMenuRequested.connect(self.contextMenu) + self.tableView.verticalHeader().hide() + self.tableView.doubleClicked.connect(self.tv_double_clicked) + + # define colors for highlighted cells + p = self.tableView.palette() + p.setColor(QPalette.Highlight, Qt.black) + p.setColor(QPalette.HighlightedText, Qt.white) + self.tableView.setPalette(p) + + # QAction for pinning columns + self.pinAction = QAction('Pin column', self) + self.pinAction.setCheckable(True) + self.pinAction.toggled.connect(self.pinColumn) + + # Filter columns by name + self.lineEditFilter = QtWidgets.QLineEdit(self) + self.lineEditFilter.setPlaceholderText('Filter columns') + self.lineEditFilter.textChanged.connect(self.changeFilter) + self.lineEditFilter.setMinimumWidth(200) + + # colormap picker + self.comboboxColormap = 
QtWidgets.QComboBox(self) + colormaps = {self.tableModel.colormap, 'inferno', 'magma', 'plasma', 'summer'} + self.comboboxColormap.addItems(sorted(list(colormaps))) + self.comboboxColormap.setCurrentText(self.tableModel.colormap) + self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormap) + + # slider for alpha values + self.sliderAlpha = QtWidgets.QSlider(Qt.Horizontal, self) + self.sliderAlpha.setMaximumWidth(100) + self.sliderAlpha.setMinimum(0) + self.sliderAlpha.setMaximum(255) + self.sliderAlpha.setValue(self.tableModel.alpha) + self.sliderAlpha.valueChanged.connect(self.tableModel.setAlpha) + + # Horizontal layout hLayout = QtWidgets.QHBoxLayout() - self.pathLE = QtWidgets.QLineEdit(self) - hLayout.addWidget(self.pathLE) - self.loadBtn = QtWidgets.QPushButton("Select File", self) - hLayout.addWidget(self.loadBtn) + hLayout.addWidget(self.lineEditFilter) + hLayout.addSpacing(50) + hLayout.addWidget(QtWidgets.QLabel('Colormap', self)) + hLayout.addWidget(self.comboboxColormap) + hLayout.addWidget(QtWidgets.QLabel('Alpha', self)) + hLayout.addWidget(self.sliderAlpha) + hLayout.addSpacing(50) + hLayout.addWidget(self.pushButtonLoad) + + # Vertical layout + vLayout = QtWidgets.QVBoxLayout(self) vLayout.addLayout(hLayout) - self.pandasTv = QtWidgets.QTableView(self) - vLayout.addWidget(self.pandasTv) - self.loadBtn.clicked.connect(self.load_file) - self.pandasTv.setSortingEnabled(True) - self.pandasTv.doubleClicked.connect(self.tv_double_clicked) + vLayout.addWidget(self.tableView) + + # Recover layout from QSettings + self.settings = QSettings() + self.settings.beginGroup('MainWindow') + self.resize(self.settings.value('size', QSize(800, 600), QSize)) + self.comboboxColormap.setCurrentText(self.settings.value('colormap', 'plasma', str)) + self.sliderAlpha.setValue(self.settings.value('alpha', 255, int)) + self.settings.endGroup() + self.wplot = PlotWindow(wheel=wheel) self.wplot.show() + self.tableModel.dataChanged.connect(self.wplot.canvas.draw) + self.wheel = wheel - def load_file(self): - fileName, _ = QtWidgets.QFileDialog.getOpenFileName( - self, "Open File", "", "CSV Files (*.csv)") - self.pathLE.setText(fileName) + def closeEvent(self, _) -> bool: + self.settings.beginGroup('MainWindow') + self.settings.setValue('size', self.size()) + self.settings.setValue('colormap', self.tableModel.colormap) + self.settings.setValue('alpha', self.tableModel.alpha) + self.settings.endGroup() + self.wplot.close() + + def showEvent(self, a0: QShowEvent) -> None: + super().showEvent(a0) + self.activateWindow() + + def contextMenu(self, pos: QPoint): + idx = self.sender().logicalIndexAt(pos) + action = self.pinAction + action.setData(idx) + action.setChecked(idx in self._pinnedColumns) + menu = QMenu(self) + menu.addAction(action) + menu.exec(self.sender().mapToGlobal(pos)) + + @pyqtSlot(bool) + @pyqtSlot(bool, int) + def pinColumn(self, pin: bool, idx: int | None = None): + idx = idx if idx is not None else self.sender().data() + if not pin and idx in self._pinnedColumns: + self._pinnedColumns.remove(idx) + if pin and idx not in self._pinnedColumns: + self._pinnedColumns.append(idx) + self.changeFilter(self.lineEditFilter.text()) + + def changeFilter(self, string: str): + headers = [ + self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).lower() + for x in range(self.tableModel.columnCount()) + ] + tokens = [y.lower() for y in (x.strip() for x in string.split(',')) if len(y)] + showAll = len(tokens) == 0 + for idx, column in enumerate(headers): + show = showAll or any((t 
in column for t in tokens)) or idx in self._pinnedColumns
+            self.tableView.setColumnHidden(idx, not show)
+
+    def loadFile(self):
+        fileName, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Open File', '', 'CSV Files (*.csv)')
+        if len(fileName) == 0:
+            return
         df = pd.read_csv(fileName)
-        self.update_df(df)
-
-    def update_df(self, df):
-        model = DataFrameModel(df)
-        self.pandasTv.setModel(model)
-        self.wplot.canvas.draw()
-
-    def tv_double_clicked(self):
-        df = self.pandasTv.model()._dataframe
-        ind = self.pandasTv.currentIndex()
-        start = df.loc[ind.row()]['intervals_0']
-        finish = df.loc[ind.row()]['intervals_1']
-        dt = finish - start
+        self.updateDataframe(df)
+
+    def updateDataframe(self, df: pd.DataFrame):
+        # clear pinned columns
+        self._pinnedColumns = []
+
+        # try to identify and sort columns containing timestamps
+        col_names = df.select_dtypes('number').columns
+        df_interp = df[col_names].replace([-np.inf, np.inf], np.nan)
+        df_interp = df_interp.interpolate(limit_direction='both')
+        cols_mono = col_names[[df_interp[c].is_monotonic_increasing for c in col_names]]
+        cols_mono = [c for c in cols_mono if df[c].nunique() > 1]
+        cols_mono = df_interp[cols_mono].mean().sort_values().keys()
+        for idx, col_name in enumerate(cols_mono):
+            df.insert(idx, col_name, df.pop(col_name))
+
+        # columns containing boolean values are sorted to the end
+        # of those, columns containing 'pass' in their title will be sorted by number of False values
+        col_names = df.columns
+        cols_bool = list(df.select_dtypes(['bool', 'boolean']).columns)
+        cols_pass = [c for c in cols_bool if 'pass' in c]
+        cols_bool = [c for c in cols_bool if c not in cols_pass]  # drop the 'pass' columns here; they are sorted and re-appended below
+        cols_pass = list((~df[cols_pass]).sum().sort_values().keys())
+        cols_bool += cols_pass
+        for col_name in cols_bool:
+            df = df.join(df.pop(col_name))
+
+        # trial_no should always be the first column
+        if 'trial_no' in col_names:
+            df.insert(0, 'trial_no', df.pop('trial_no'))
+
+        # define columns that should be pinned by default
+        for col in ['trial_no']:
+            self._pinnedColumns.append(df.columns.get_loc(col))
+
+        self.tableModel.setDataFrame(df)
+
+    def tv_double_clicked(self, index: QModelIndex):
+        data = self.tableModel.dataFrame.iloc[index.row()]
+        t0 = data['intervals_0']
+        t1 = data['intervals_1']
+        dt = t1 - t0
         if self.wheel:
-            idx = np.searchsorted(
-                self.wheel['re_ts'], np.array([start - dt / 10, finish + dt / 10]))
+            idx = np.searchsorted(self.wheel['re_ts'], np.array([t0 - dt / 10, t1 + dt / 10]))
             period = self.wheel['re_pos'][idx[0]:idx[1]]
             if period.size == 0:
-                _logger.warning('No wheel data during trial #%i', ind.row())
+                _logger.warning('No wheel data during trial #%i', index.row())
             else:
                 min_val, max_val = np.min(period), np.max(period)
                 self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1)
-            self.wplot.canvas.ax2.set_xlim(start - dt / 10, finish + dt / 10)
-            self.wplot.canvas.ax.set_xlim(start - dt / 10, finish + dt / 10)
-
+            self.wplot.canvas.ax2.set_xlim(t0 - dt / 10, t1 + dt / 10)
+            self.wplot.canvas.ax.set_xlim(t0 - dt / 10, t1 + dt / 10)
+        self.wplot.setWindowTitle(f"Trial {data.get('trial_no', '?')}")
         self.wplot.canvas.draw()
 
 
 def viewqc(qc=None, title=None, wheel=None):
-    qt.create_app()
+    app = qt.create_app()
+    app.setStyle('Fusion')
+    QCoreApplication.setOrganizationName('International Brain Laboratory')
+    QCoreApplication.setOrganizationDomain('internationalbrainlab.org')
+    QCoreApplication.setApplicationName('QC Viewer')
     qcw = GraphWindow(wheel=wheel)
     qcw.setWindowTitle(title)
    
if qc is not None: - qcw.update_df(qc) + qcw.updateDataframe(qc) qcw.show() return qcw diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index 89a8d172f..b9f212a5c 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -140,7 +140,8 @@ def create_plots(self, axes, 'ymin': 0, 'ymax': 4, 'linewidth': 2, - 'ax': axes + 'ax': axes, + 'alpha': 0.5, } bnc1 = self.qc.extractor.frame_ttls @@ -240,7 +241,8 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N if isinstance(qc_or_session, QcFrame): qc = qc_or_session elif isinstance(qc_or_session, TaskQC): - qc = QcFrame(qc_or_session) + task_qc = qc_or_session + qc = QcFrame(task_qc) else: # assumed to be eid or session path one = one or ONE(mode='local' if local else 'auto') if not is_session_path(Path(qc_or_session)): @@ -284,8 +286,22 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N trial_events=list(events), color_map=cm, linestyle=ls) + # Update table and callbacks - w.update_df(qc.frame) + n_trials = qc.frame.shape[0] + if 'task_qc' in locals(): + df_trials = pd.DataFrame({ + k: v for k, v in task_qc.extractor.data.items() + if v.size == n_trials and not k.startswith('wheel') + }) + df = df_trials.merge(qc.frame, left_index=True, right_index=True) + else: + df = qc.frame + df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials}) + df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True) + df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True) + df = df.merge(df_pass.astype('boolean'), left_index=True, right_index=True) + w.updateDataframe(df) qt.run_app() return qc diff --git a/ibllib/tests/qc/test_task_qc_viewer.py b/ibllib/tests/qc/test_task_qc_viewer.py index 6db045f91..7115f371f 100644 --- a/ibllib/tests/qc/test_task_qc_viewer.py +++ b/ibllib/tests/qc/test_task_qc_viewer.py @@ -66,6 +66,7 @@ def test_show_session_task_qc(self, trials_tasks_mock, run_app_mock): qc_mock.compute_session_status.return_value = ('Fail', qc_mock.metrics, {'foo': 'FAIL'}) qc_mock.extractor.data = {'intervals': np.array([[0, 1]])} qc_mock.extractor.frame_ttls = qc_mock.extractor.audio_ttls = qc_mock.extractor.bpod_ttls = mock.MagicMock() + qc_mock.passed = dict() active_task = mock.Mock(spec=ChoiceWorldTrialsNidq, unsafe=True) active_task.run_qc.return_value = qc_mock diff --git a/release_notes.md b/release_notes.md index e2eb6ce78..a89c2b60d 100644 --- a/release_notes.md +++ b/release_notes.md @@ -1,3 +1,15 @@ +## Release Note 3.2.0 + +### features +- Add session delay info during registration of Bpod session +- Add detailed criteria info to behaviour plots +- Add column filtering, sorting and color-coding of values to metrics table of + task_qc_viewer + +### Bugfixes +- Read in updated json keys from task settings to establish ready4recording +- Handle extraction of sessions with dud spacers + ## Release Note 3.1.0 ### features diff --git a/requirements.txt b/requirements.txt index bf0f3128e..b890b3e5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,9 +25,11 @@ tqdm>=4.32.1 iblatlas>=0.5.3 ibl-neuropixel>=1.5.0 iblutil>=1.13.0 +iblqt>=0.3.2 mtscomp>=1.0.1 ONE-api>=2.11 phylib>=2.6.0 psychofit slidingRP>=1.1.1 # steinmetz lab refractory period metrics pyqt5 +ibl-style diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..253516e9f --- /dev/null +++ b/ruff.toml @@ -0,0 +1,4 @@ +line-length = 
130 + +[format] +quote-style = "single"
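
The table changes to the task QC viewer are easiest to assess interactively. A minimal sketch of how the reworked metrics table is reached (the eID is a placeholder, and `local=True` assumes the session data are already cached):

```python
from ibllib.qc.task_qc_viewer.task_qc import show_session_task_qc

eid = 'aaaaaaaa-bbbb-4bbb-8bbb-bbbbbbbbbbbb'  # placeholder session eID
# opens the QC table with sortable, filterable, colour-coded per-trial metrics,
# now including the extracted trial data and per-check pass columns
qc = show_session_task_qc(eid, bpod_only=True, local=True)
```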