diff --git a/.gitignore b/.gitignore index 906c5d9ac..e291b8572 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__ python/scratch .idea/* .vscode/ +*.code-workspace *checkpoint.ipynb build/ venv/ diff --git a/ibllib/ephys/ephysqc.py b/ibllib/ephys/ephysqc.py index b8721bfe2..16ab9f870 100644 --- a/ibllib/ephys/ephysqc.py +++ b/ibllib/ephys/ephysqc.py @@ -580,7 +580,7 @@ def _qc_from_path(sess_path, display=True): sync, chmap = ephys_fpga.get_main_probe_sync(sess_path, bin_exists=False) _ = ephys_fpga.extract_all(sess_path, output_path=temp_alf_folder, save=True) # check that the output is complete - fpga_trials = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display) + fpga_trials, *_ = ephys_fpga.extract_behaviour_sync(sync, chmap=chmap, display=display) # align with the bpod bpod2fpga = ephys_fpga.align_with_bpod(temp_alf_folder.parent) alf_trials = alfio.load_object(temp_alf_folder, 'trials') diff --git a/ibllib/io/extractors/biased_trials.py b/ibllib/io/extractors/biased_trials.py index c7c16d6c0..16d8f8111 100644 --- a/ibllib/io/extractors/biased_trials.py +++ b/ibllib/io/extractors/biased_trials.py @@ -95,12 +95,12 @@ class TrialsTableBiased(BaseBpodTrialsExtractor): intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times Additionally extracts the following wheel data: - wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude + wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude """ save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None) - var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 
'is_final_movement') + var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement') def _extract(self, extractor_classes=None, **kwargs): base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType, @@ -120,13 +120,13 @@ class TrialsTableEphys(BaseBpodTrialsExtractor): intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times Additionally extracts the following wheel data: - wheel_timestamps, wheel_position, wheel_moves_intervals, wheel_moves_peak_amplitude + wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude """ save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, '_ibl_trials.quiescencePeriod.npy') - var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', + var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence') def _extract(self, extractor_classes=None, **kwargs): @@ -154,16 +154,16 @@ class BiasedTrials(BaseBpodTrialsExtractor): None, None, '_ibl_trials.quiescencePeriod.npy') var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', - 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 
'is_final_movement', 'included', + 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included', 'phase', 'position', 'quiescence') - def _extract(self, extractor_classes=None, **kwargs): + def _extract(self, extractor_classes=None, **kwargs) -> dict: base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials, PhasePosQuiescence] # Exclude from trials table out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) - return tuple(out.pop(x) for x in self.var_names) + return {k: out[k] for k in self.var_names} class EphysTrials(BaseBpodTrialsExtractor): @@ -177,16 +177,16 @@ class EphysTrials(BaseBpodTrialsExtractor): '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy') var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', - 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included', + 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included', 'phase', 'position', 'quiescence') - def _extract(self, extractor_classes=None, **kwargs): + def _extract(self, extractor_classes=None, **kwargs) -> dict: base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence] # Exclude from trials table out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) - return tuple(out.pop(x) for x in 
self.var_names) + return {k: out[k] for k in self.var_names} def extract_all(session_path, save=False, bpod_trials=False, settings=False, extra_classes=None, diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index f3e5dbd1d..bf3c95528 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -16,6 +16,7 @@ from ibllib.io.extractors.base import get_session_extractor_type from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map import ibllib.io.raw_data_loaders as raw +import ibllib.io.extractors.video_motion as vmotion from ibllib.io.extractors.base import ( BaseBpodTrialsExtractor, BaseExtractor, @@ -148,12 +149,30 @@ def _extract(self, sync=None, chmap=None, video_path=None, sync_label='audio', except AssertionError as ex: _logger.critical('Failed to extract using %s: %s', sync_label, ex) - # If you reach here extracting using sync TTLs was not possible - _logger.warning('Alignment by wheel data not yet implemented') + # If you reach here extracting using sync TTLs was not possible, we attempt to align using wheel motion energy + _logger.warning('Attempting to align using wheel') + + try: + if self.label not in ['left', 'right']: + # Can only use wheel alignment for left and right cameras + raise ValueError(f'Wheel alignment not supported for {self.label} camera') + + motion_class = vmotion.MotionAlignmentFullSession(self.session_path, self.label, upload=True) + new_times = motion_class.process() + if not motion_class.qc_outcome: + raise ValueError(f'Wheel alignment failed to pass qc: {motion_class.qc}') + else: + _logger.warning(f'Wheel alignment successful, qc: {motion_class.qc}') + return new_times + + except Exception as err: + _logger.critical(f'Failed to align with wheel: {err}') + if length < raw_ts.size: df = raw_ts.size - length _logger.info(f'Discarding first {df} pulses') raw_ts = raw_ts[df:] + return raw_ts diff --git a/ibllib/io/extractors/ephys_fpga.py 
b/ibllib/io/extractors/ephys_fpga.py index 98bdcdd25..74ac1e551 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -17,7 +17,7 @@ from iblutil.spacer import Spacer import ibllib.exceptions as err -from ibllib.io import raw_data_loaders, session_params +from ibllib.io import raw_data_loaders as raw, session_params from ibllib.io.extractors.bpod_trials import extract_all as bpod_extract_all import ibllib.io.extractors.base as extractors_base from ibllib.io.extractors.training_wheel import extract_wheel_moves @@ -554,7 +554,7 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm ax.set_yticks([0, 1, 2, 3, 4, 5]) ax.set_ylim([0, 5]) - return trials + return trials, frame2ttl, audio, bpod def extract_sync(session_path, overwrite=False, ephys_files=None, namespace='spikeglx'): @@ -734,6 +734,7 @@ def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs): super().__init__(*args, **kwargs) self.bpod2fpga = None self.bpod_trials = bpod_trials + self.frame2ttl = self.audio = self.bpod = self.settings = None if bpod_extractor: self.bpod_extractor = bpod_extractor self._update_var_names() @@ -750,14 +751,37 @@ def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None): A set of Bpod trials fields to keep. bpod_rsync_fields : tuple A set of Bpod trials fields to sync to the DAQ times. 
- - TODO Turn into property getter; requires ensuring the output field are the same for legacy """ if self.bpod_extractor: - self.var_names = self.bpod_extractor.var_names - self.save_names = self.bpod_extractor.save_names - self.bpod_rsync_fields = bpod_rsync_fields or self._time_fields(self.bpod_extractor.var_names) - self.bpod_fields = bpod_fields or [x for x in self.bpod_extractor.var_names if x not in self.bpod_rsync_fields] + for var_name, save_name in zip(self.bpod_extractor.var_names, self.bpod_extractor.save_names): + if var_name not in self.var_names: + self.var_names += (var_name,) + self.save_names += (save_name,) + + # self.var_names = self.bpod_extractor.var_names + # self.save_names = self.bpod_extractor.save_names + self.settings = self.bpod_extractor.settings # This is used by the TaskQC + self.bpod_rsync_fields = bpod_rsync_fields + if self.bpod_rsync_fields is None: + self.bpod_rsync_fields = tuple(self._time_fields(self.bpod_extractor.var_names)) + if 'table' in self.bpod_extractor.var_names: + if not self.bpod_trials: + self.bpod_trials = self.bpod_extractor.extract(save=False) + table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() + self.bpod_rsync_fields += tuple(self._time_fields(table_keys)) + elif bpod_rsync_fields: + self.bpod_rsync_fields = bpod_rsync_fields + excluded = (*self.bpod_rsync_fields, 'table') + if bpod_fields: + assert not set(self.bpod_fields).intersection(excluded), 'bpod_fields must not also be bpod_rsync_fields' + self.bpod_fields = bpod_fields + elif self.bpod_extractor: + self.bpod_fields = tuple(x for x in self.bpod_extractor.var_names if x not in excluded) + if 'table' in self.bpod_extractor.var_names: + if not self.bpod_trials: + self.bpod_trials = self.bpod_extractor.extract(save=False) + table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() + self.bpod_fields += (*[x for x in table_keys if x not in excluded], self.sync_field + '_bpod') @staticmethod def _time_fields(trials_attr) -> 
set: @@ -778,7 +802,8 @@ def _time_fields(trials_attr) -> set: pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$') return set(filter(pattern.match, trials_attr)) - def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs): + def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', + task_collection='raw_behavior_data', **kwargs) -> dict: """Extracts ephys trials by combining Bpod and FPGA sync pulses""" # extract the behaviour data from bpod if sync is None or chmap is None: @@ -804,7 +829,8 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task else: tmin = tmax = None - fpga_trials = extract_behaviour_sync( + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync( sync=sync, chmap=chmap, bpod_trials=self.bpod_trials, tmin=tmin, tmax=tmax) assert self.sync_field in self.bpod_trials and self.sync_field in fpga_trials self.bpod_trials[f'{self.sync_field}_bpod'] = np.copy(self.bpod_trials[self.sync_field]) @@ -827,18 +853,20 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task # extract the wheel data wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) from ibllib.io.extractors.training_wheel import extract_first_movement_times - settings = raw_data_loaders.load_settings(session_path=self.session_path, task_collection=task_collection) - min_qt = settings.get('QUIESCENT_PERIOD', None) + if not self.settings: + self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) + min_qt = self.settings.get('QUIESCENT_PERIOD', None) first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) out.update({'firstMovement_times': first_move_onsets}) # Re-create trials table trials_table = alfio.AlfBunch({x: out.pop(x) for x in 
table_columns}) out['table'] = trials_table.to_df() + out.update({f'wheel_{k}': v for k, v in wheel.items()}) + out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) out = {k: out[k] for k in self.var_names if k in out} # Reorder output - assert tuple(filter(lambda x: 'wheel' not in x, self.var_names)) == tuple(out.keys()) - return [out[k] for k in out] + [wheel['timestamps'], wheel['position'], - moves['intervals'], moves['peakAmplitude']] + assert self.var_names == tuple(out.keys()) + return out def get_wheel_positions(self, *args, **kwargs): """Extract wheel and wheelMoves objects. @@ -882,7 +910,7 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ If save is True, a list of file paths to the extracted data. """ # Extract Bpod trials - bpod_raw = raw_data_loaders.load_data(session_path, task_collection=task_collection) + bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' bpod_trials, *_ = bpod_extract_all( session_path=session_path, bpod_trials=bpod_raw, task_collection=task_collection, diff --git a/ibllib/io/extractors/habituation_trials.py b/ibllib/io/extractors/habituation_trials.py index a78a57eef..9dedbd3d5 100644 --- a/ibllib/io/extractors/habituation_trials.py +++ b/ibllib/io/extractors/habituation_trials.py @@ -15,16 +15,15 @@ class HabituationTrials(BaseBpodTrialsExtractor): var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals', 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times', - 'stimCenterTrigger_times', 'stimCenter_times') + 'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - exclude = ['itiIn_times', 'stimOffTrigger_times', - 'stimCenter_times', 'stimCenterTrigger_times'] - self.save_names = 
tuple([f'_ibl_trials.{x}.npy' if x not in exclude else None - for x in self.var_names]) + exclude = ['itiIn_times', 'stimOffTrigger_times', 'stimCenter_times', + 'stimCenterTrigger_times', 'position', 'phase'] + self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names) - def _extract(self): + def _extract(self) -> dict: # Extract all trials... # Get all stim_sync events detected @@ -101,9 +100,14 @@ def _extract(self): ["iti"][0][0] for tr in self.bpod_trials] ) + # Phase and position + out['position'] = np.array([t['position'] for t in self.bpod_trials]) + out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials]) + # NB: We lose the last trial because the stim off event occurs at trial_num + 1 n_trials = out['stimOff_times'].size - return [out[k][:n_trials] for k in self.var_names] + # return [out[k][:n_trials] for k in self.var_names] + return {k: out[k][:n_trials] for k in self.var_names} def extract_all(session_path, save=False, bpod_trials=False, settings=False, task_collection='raw_behavior_data', save_path=None): diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 93491945e..561bb6343 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -100,7 +100,7 @@ def __init__(self, *args, sync_collection='raw_sync_data', **kwargs): super().__init__(*args, **kwargs) self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') - def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs): + def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: if not (sync or chmap): sync, chmap = load_timeline_sync_and_chmap( self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) @@ -110,20 +110,17 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa trials = super()._extract(sync, chmap, 
sync_collection, extractor_type='ephys', **kwargs) # If no protocol number is defined, trim timestamps based on Bpod trials intervals - trials_table = trials[self.var_names.index('table')] + trials_table = trials['table'] bpod = get_sync_fronts(sync, chmap['bpod']) if kwargs.get('protocol_number') is None: tmin = trials_table.intervals_0.iloc[0] - 1 tmax = trials_table.intervals_1.iloc[-1] # Ensure wheel is cut off based on trials - wheel_ts_idx = self.var_names.index('wheel_timestamps') - mask = np.logical_and(tmin <= trials[wheel_ts_idx], trials[wheel_ts_idx] <= tmax) - trials[wheel_ts_idx] = trials[wheel_ts_idx][mask] - wheel_pos_idx = self.var_names.index('wheel_position') - trials[wheel_pos_idx] = trials[wheel_pos_idx][mask] - move_idx = self.var_names.index('wheelMoves_intervals') - mask = np.logical_and(trials[move_idx][:, 0] >= tmin, trials[move_idx][:, 0] <= tmax) - trials[move_idx] = trials[move_idx][mask, :] + mask = np.logical_and(tmin <= trials['wheel_timestamps'], trials['wheel_timestamps'] <= tmax) + trials['wheel_timestamps'] = trials['wheel_timestamps'][mask] + trials['wheel_position'] = trials['wheel_position'][mask] + mask = np.logical_and(trials['wheelMoves_intervals'][:, 0] >= tmin, trials['wheelMoves_intervals'][:, 0] <= tmax) + trials['wheelMoves_intervals'] = trials['wheelMoves_intervals'][mask, :] else: tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod) bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax) @@ -138,7 +135,7 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa valve_open_times = self.get_valve_open_times(driver_ttls=driver_out) assert len(valve_open_times) == sum(trials_table.feedbackType == 1) # TODO Relax assertion correct = trials_table.feedbackType == 1 - trials[self.var_names.index('valveOpen_times')][correct] = valve_open_times + trials['valveOpen_times'][correct] = valve_open_times trials_table.feedback_times[correct] = valve_open_times # Replace audio 
events @@ -191,7 +188,7 @@ def first_true(arr): trials_table.feedback_times[~correct] = error_cue trials_table.goCue_times = go_cue - return trials + return {k: trials[k] for k in self.var_names} def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None): """ diff --git a/ibllib/io/extractors/training_trials.py b/ibllib/io/extractors/training_trials.py index dc13ed7dd..41a69d815 100644 --- a/ibllib/io/extractors/training_trials.py +++ b/ibllib/io/extractors/training_trials.py @@ -682,8 +682,8 @@ class TrialsTable(BaseBpodTrialsExtractor): """ save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None) - var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement') + var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement') def _extract(self, extractor_classes=None, **kwargs): base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType, @@ -703,16 +703,16 @@ class TrainingTrials(BaseBpodTrialsExtractor): '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, None, None, None) var_names = ('repNum', 'goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', - 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', + 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'phase', 'position', 'quiescence') - def 
_extract(self): + def _extract(self) -> dict: base = [RepNum, GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTable, PhasePosQuiescence] out, _ = run_extractor_classes( base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) - return tuple(out.pop(x) for x in self.var_names) + return {k: out[k] for k in self.var_names} def extract_all(session_path, save=False, bpod_trials=None, settings=None, task_collection='raw_behavior_data', save_path=None): diff --git a/ibllib/io/extractors/training_wheel.py b/ibllib/io/extractors/training_wheel.py index 617b5f1df..2f1aded8c 100644 --- a/ibllib/io/extractors/training_wheel.py +++ b/ibllib/io/extractors/training_wheel.py @@ -385,8 +385,8 @@ class Wheel(BaseBpodTrialsExtractor): save_names = ('_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, '_ibl_trials.firstMovement_times.npy', None) - var_names = ('wheel_timestamps', 'wheel_position', 'wheel_moves_intervals', - 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'firstMovement_times', + var_names = ('wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', + 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'firstMovement_times', 'is_final_movement') def _extract(self): diff --git a/ibllib/io/extractors/video_motion.py b/ibllib/io/extractors/video_motion.py index ef75187b5..4d567b2d0 100644 --- a/ibllib/io/extractors/video_motion.py +++ b/ibllib/io/extractors/video_motion.py @@ -4,21 +4,32 @@ """ import matplotlib import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec from matplotlib.widgets import RectangleSelector import numpy as np -from scipy import signal +from scipy import signal, ndimage, interpolate import cv2 from itertools import cycle import matplotlib.animation as animation import logging from pathlib 
import Path +from joblib import Parallel, delayed, cpu_count +from neurodsp.utils import WindowGenerator from one.api import ONE import ibllib.io.video as vidio from iblutil.util import Bunch +from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map +import ibllib.io.raw_data_loaders as raw +import ibllib.io.extractors.camera as cam +from ibllib.plots.snapshot import ReportSnapshot import brainbox.video as video import brainbox.behavior.wheel as wh +from brainbox.singlecell import bin_spikes +from brainbox.behavior.dlc import likelihood_threshold, get_speed +from brainbox.task.trials import find_trial_ids import one.alf.io as alfio +from one.alf.exceptions import ALFObjectNotFound from one.alf.spec import is_session_path, is_uuid_string @@ -383,3 +394,521 @@ def process_key(event): anim.save(str(filename), writer=writer) else: plt.show() + + +class MotionAlignmentFullSession: + def __init__(self, session_path, label, **kwargs): + self.session_path = session_path + self.label = label + self.threshold = kwargs.get('threshold', 20) + self.upload = kwargs.get('upload', False) + self.twin = kwargs.get('twin', 150) + self.nprocess = kwargs.get('nprocess', int(cpu_count() - cpu_count() / 4)) + + self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None)) + self.roi, self.mask = self.get_roi_mask() + + if self.upload: + self.one = ONE(mode='remote') + self.one.alyx.authenticate() + self.eid = self.one.path2eid(self.session_path) + + def load_data(self, sync='nidq', location=None): + def fix_keys(alf_object): + ob = Bunch() + for key in alf_object.keys(): + vals = alf_object[key] + ob[key.split('.')[0]] = vals + return ob + + alf_path = self.session_path.joinpath('alf') + wheel = (fix_keys(alfio.load_object(alf_path, 'wheel')) if location == 'SDSC' + else alfio.load_object(alf_path, 'wheel')) + self.wheel_timestamps = wheel.timestamps + wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, 
freq=1000) + self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000) + self.camera_times = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.times*.npy'))) + self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob( + f'_iblrig_{self.label}Camera.raw*.mp4'))) + self.camera_meta = vidio.get_video_meta(self.camera_path) + + # TODO should read in the description file to get the correct sync location + if sync == 'nidq': + sync, chmap = get_sync_and_chn_map(self.session_path, sync_collection='raw_ephys_data') + sr = get_sync_fronts(sync, chmap[f'{self.label}_camera']) + self.ttls = sr.times[::2] + else: + cam_extractor = cam.CameraTimestampsBpod(session_path=self.session_path) + cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data') + self.ttls = cam_extractor._times_from_bpod() + + self.tdiff = self.ttls.size - self.camera_meta['length'] + + if self.tdiff < 0: + self.ttl_times = self.ttls + self.times = np.r_[self.ttl_times, np.full((np.abs(self.tdiff)), np.nan)] + self.short_flag = True + elif self.tdiff > 0: + self.ttl_times = self.ttls[self.tdiff:] + self.times = self.ttls[self.tdiff:] + self.short_flag = False + + self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) + + try: + self.trials = alfio.load_file_content(next(alf_path.glob('_ibl_trials.table*.pqt'))) + self.dlc = alfio.load_file_content(next(alf_path.glob(f'_ibl_{self.label}Camera.dlc*.pqt'))) + self.dlc = likelihood_threshold(self.dlc) + self.behavior = True + except (ALFObjectNotFound, StopIteration): + self.behavior = False + + self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0]) + + def get_roi_mask(self): + + if self.label == 'right': + roi = ((450, 512), (120, 200)) + else: + roi = ((900, 1024), (850, 1010)) + roi_mask = (*[slice(*r) for r in roi], 0) + + return roi, roi_mask + + def find_contaminated_frames(self, video_frames, thresold=20, 
normalise=True): + high = np.zeros((video_frames.shape[0])) + for idx, frame in enumerate(video_frames): + ret, _ = cv2.threshold(cv2.GaussianBlur(frame, (5, 5), 0), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + high[idx] = ret + + if normalise: + high -= np.min(high) + + contaminated_frames = np.where(high > thresold)[0] + + return contaminated_frames + + def compute_motion_energy(self, first, last, wg, iw): + + if iw == wg.nwin - 1: + return + + cap = cv2.VideoCapture(self.camera_path) + frames = vidio.get_video_frames_preload(cap, np.arange(first, last), mask=self.mask) + idx = self.find_contaminated_frames(frames, self.threshold) + + if len(idx) != 0: + + before_status = False + after_status = False + + counter = 0 + n_frames = 200 + while np.any(idx == 0) and counter < 20 and iw != 0: + n_before_offset = (counter + 1) * n_frames + first -= n_frames + extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(first - n_frames, first), + mask=self.mask) + frames = np.concatenate([extra_frames, frames], axis=0) + + idx = self.find_contaminated_frames(frames, self.threshold) + before_status = True + counter += 1 + if counter > 0: + print(f'In before: {counter}') + + counter = 0 + while np.any(idx == frames.shape[0] - 1) and counter < 20 and iw != wg.nwin - 1: + n_after_offset = (counter + 1) * n_frames + last += n_frames + extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last, last + n_frames), + mask=self.mask) + frames = np.concatenate([frames, extra_frames], axis=0) + idx = self.find_contaminated_frames(frames, self.threshold) + after_status = True + counter += 1 + + if counter > 0: + print(f'In after: {counter}') + + intervals = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1) + for ints in intervals: + if len(ints) > 0 and ints[0] == 0: + ints = ints[1:] + if len(ints) > 0 and ints[-1] == frames.shape[0] - 1: + ints = ints[:-1] + th_all = np.zeros_like(frames[0]) + for idx in ints: + img = np.copy(frames[idx]) + 
blur = cv2.GaussianBlur(img, (5, 5), 0) + ret, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + th = cv2.GaussianBlur(th, (5, 5), 10) + th_all += th + vals = np.mean(np.dstack([frames[ints[0] - 1], frames[ints[-1] + 1]]), axis=-1) + for idx in ints: + img = frames[idx] + img[th_all > 0] = vals[th_all > 0] + + if before_status: + frames = frames[n_before_offset:] + if after_status: + frames = frames[:(-1 * n_after_offset)] + + frame_me, _ = video.motion_energy(frames, diff=2, normalize=False) + + cap.release() + + return frame_me[2:] + + def compute_shifts(self, times, me, first, last, iw, wg): + + if iw == wg.nwin - 1: + return np.nan, np.nan + t_first = times[first] + t_last = times[last] + if np.isnan(t_last) and np.isnan(t_first): + return np.nan, np.nan + elif np.isnan(t_last): + t_last = times[np.where(~np.isnan(times))[0][-1]] + + mask = np.logical_and(times >= t_first, times <= t_last) + align_me = me[np.where(mask)[0]] + align_me = (align_me - np.nanmin(align_me)) / (np.nanmax(align_me) - np.nanmin(align_me)) + + # Find closest timepoints in wheel that match the camera times + wh_mask = np.logical_and(self.wheel_time >= t_first, self.wheel_time <= t_last) + if np.sum(wh_mask) == 0: + return np.nan, np.nan + xs = np.searchsorted(self.wheel_time[wh_mask], times[mask]) + xs[xs == np.sum(wh_mask)] = np.sum(wh_mask) - 1 + # Convert to normalized speed + vs = np.abs(self.wheel_vel[wh_mask][xs]) + vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) + + isnan = np.isnan(align_me) + + if np.sum(isnan) > 0: + where_nan = np.where(isnan)[0] + assert where_nan[0] == 0 + assert where_nan[-1] == np.sum(isnan) - 1 + + if np.all(isnan): + return np.nan, np.nan + + xcorr = signal.correlate(align_me[~isnan], vs[~isnan]) + shift = np.nanargmax(xcorr) - align_me[~isnan].size + 2 + + return shift, t_first + (t_last - t_first) / 2 + + def clean_shifts(self, x, n=1): + y = x.copy() + dy = np.diff(y, prepend=y[0]) + while True: + pos = np.where(dy == 1)[0] if 
n == 1 else np.where(dy > 2)[0] + # added frames: this doesn't make sense and this is noise + if pos.size == 0: + break + neg = np.where(dy == -1)[0] if n == 1 else np.where(dy < -2)[0] + + if len(pos) > len(neg): + neg = np.append(neg, dy.size - 1) + + iss = np.minimum(np.searchsorted(neg, pos), neg.size - 1) + imin = np.argmin(np.minimum(np.abs(pos - neg[iss - 1]), np.abs(pos - neg[iss]))) + + idx = np.max([0, iss[imin] - 1]) + ineg = neg[idx:iss[imin] + 1] + ineg = ineg[np.argmin(np.abs(pos[imin] - ineg))] + dy[pos[imin]] = 0 + dy[ineg] = 0 + + return np.cumsum(dy) + y[0] + + def qc_shifts(self, shifts, shifts_filt): + + ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100 if self.tdiff < 0 else 0 + nan_per = (np.sum(np.isnan(shifts_filt)) / shifts_filt.size) * 100 + shifts_sum = np.where(np.abs(np.diff(shifts)) > 10)[0].size + shifts_filt_sum = np.where(np.abs(np.diff(shifts_filt)) > 1)[0].size + + qc = dict() + qc['ttl_per'] = ttl_per + qc['nan_per'] = nan_per + qc['shifts_sum'] = shifts_sum + qc['shifts_filt_sum'] = shifts_filt_sum + + qc_outcome = True + # If more than 10% of ttls are missing we don't get new times + if ttl_per > 10: + qc_outcome = False + # If too many of the shifts are nans it means the alignment is not accurate + if nan_per > 40: + qc_outcome = False + # If there are too many artefacts could be errors + if shifts_sum > 60: + qc_outcome = False + # If there are jumps > 1 in the filtered shifts then there is a problem + if shifts_filt_sum > 0: + qc_outcome = False + + return qc, qc_outcome + + def extract_times(self, shifts_filt, t_shifts): + + t_new = t_shifts - (shifts_filt * 1 / self.frate) + fcn = interpolate.interp1d(t_shifts, t_new, fill_value="extrapolate") + new_times = fcn(self.ttl_times) + + if self.tdiff < 0: + to_app = (np.arange(np.abs(self.tdiff), ) + 1) / self.frate + new_times[-1] + new_times = np.r_[new_times, to_app] + + return new_times + + @staticmethod + def single_cluster_raster(spike_times, events, 
trial_idx, dividers, colors, labels, weights=None, fr=True, + norm=False, + axs=None): + pre_time = 0.4 + post_time = 1 + raster_bin = 0.01 + psth_bin = 0.05 + raster, t_raster = bin_spikes( + spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=raster_bin, weights=weights) + psth, t_psth = bin_spikes( + spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=psth_bin, weights=weights) + + if fr: + psth = psth / psth_bin + + if norm: + psth = psth - np.repeat(psth[:, 0][:, np.newaxis], psth.shape[1], axis=1) + raster = raster - np.repeat(raster[:, 0][:, np.newaxis], raster.shape[1], axis=1) + + dividers = [0] + dividers + [len(trial_idx)] + if axs is None: + fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, + sharex=True) + else: + fig = axs[0].get_figure() + + label, lidx = np.unique(labels, return_index=True) + label_pos = [] + for lab, lid in zip(label, lidx): + idx = np.where(np.array(labels) == lab)[0] + for iD in range(len(idx)): + if iD == 0: + t_ids = trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1] + t_ints = dividers[idx[iD] + 1] - dividers[idx[iD]] + else: + t_ids = np.r_[t_ids, trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]] + t_ints = np.r_[t_ints, dividers[idx[iD] + 1] - dividers[idx[iD]]] + + psth_div = np.nanmean(psth[t_ids], axis=0) + std_div = np.nanstd(psth[t_ids], axis=0) / np.sqrt(len(t_ids)) + + axs[0].fill_between(t_psth, psth_div - std_div, + psth_div + std_div, alpha=0.4, color=colors[lid]) + axs[0].plot(t_psth, psth_div, alpha=1, color=colors[lid]) + + lab_max = idx[np.argmax(t_ints)] + label_pos.append((dividers[lab_max + 1] - dividers[lab_max]) / 2 + dividers[lab_max]) + + axs[1].imshow(raster[trial_idx], cmap='binary', origin='lower', + extent=[np.min(t_raster), np.max(t_raster), 0, len(trial_idx)], aspect='auto') + + width = raster_bin * 4 + for iD in range(len(dividers) - 1): + axs[1].fill_between([post_time + raster_bin / 2, post_time 
+ raster_bin / 2 + width], + [dividers[iD + 1], dividers[iD + 1]], [dividers[iD], dividers[iD]], color=colors[iD]) + + axs[1].set_xlim([-1 * pre_time, post_time + raster_bin / 2 + width]) + secax = axs[1].secondary_yaxis('right') + + secax.set_yticks(label_pos) + secax.set_yticklabels(label, rotation=90, + rotation_mode='anchor', ha='center') + for ic, c in enumerate(np.array(colors)[lidx]): + secax.get_yticklabels()[ic].set_color(c) + + axs[0].axvline(0, *axs[0].get_ylim(), c='k', ls='--', zorder=10) # TODO this doesn't always work + axs[1].axvline(0, *axs[1].get_ylim(), c='k', ls='--', zorder=10) + + return fig, axs + + def plot_with_behavior(self): + + self.dlc = likelihood_threshold(self.dlc) + trial_idx, dividers = find_trial_ids(self.trials, sort='side') + feature_ext = get_speed(self.dlc, self.camera_times, self.label, feature='paw_r') + feature_new = get_speed(self.dlc, self.new_times, self.label, feature='paw_r') + + fig = plt.figure() + fig.set_size_inches(15, 9) + gs = gridspec.GridSpec(1, 5, figure=fig, width_ratios=[4, 1, 1, 1, 3], wspace=0.3, hspace=0.5) + gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0]) + ax01 = fig.add_subplot(gs0[0, 0]) + ax02 = fig.add_subplot(gs0[1, 0]) + ax03 = fig.add_subplot(gs0[2, 0]) + gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1], height_ratios=[1, 3]) + ax11 = fig.add_subplot(gs1[0, 0]) + ax12 = fig.add_subplot(gs1[1, 0]) + gs2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 2], height_ratios=[1, 3]) + ax21 = fig.add_subplot(gs2[0, 0]) + ax22 = fig.add_subplot(gs2[1, 0]) + gs3 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 3], height_ratios=[1, 3]) + ax31 = fig.add_subplot(gs3[0, 0]) + ax32 = fig.add_subplot(gs3[1, 0]) + gs4 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 4]) + ax41 = fig.add_subplot(gs4[0, 0]) + ax42 = fig.add_subplot(gs4[1, 0]) + + ax01.plot(self.t_shifts, self.shifts, label='shifts') + ax01.plot(self.t_shifts, self.shifts_filt, 
label='shifts_filt') + ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10) + ax01.legend() + ax01.set_ylabel('Frames') + ax01.set_xlabel('Time in session') + + xs = np.searchsorted(self.ttl_times, self.t_shifts) + ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps'] + ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl') + ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10) + ax02.legend() + ax02.set_ylabel('Frames') + ax02.set_xlabel('Time in session') + + ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], + 'k', label='extracted - new') + ax03.legend() + ax03.set_ylim(-5, 5) + ax03.set_ylabel('Frames') + ax03.set_xlabel('Time in session') + + self.single_cluster_raster(self.wheel_timestamps, self.trials['firstMovement_times'].values, trial_idx, dividers, + ['g', 'y'], ['left', 'right'], weights=self.wheel_vel, fr=False, axs=[ax11, ax12]) + ax11.sharex(ax12) + ax11.set_ylabel('Wheel velocity') + ax11.set_title('Wheel') + ax12.set_xlabel('Time from first move') + + self.single_cluster_raster(self.camera_times, self.trials['firstMovement_times'].values, trial_idx, dividers, + ['g', 'y'], ['left', 'right'], weights=feature_ext, fr=False, axs=[ax21, ax22]) + ax21.sharex(ax22) + ax21.set_ylabel('Paw r velocity') + ax21.set_title('Extracted times') + ax22.set_xlabel('Time from first move') + + self.single_cluster_raster(self.new_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'], + ['left', 'right'], weights=feature_new, fr=False, axs=[ax31, ax32]) + ax31.sharex(ax32) + ax31.set_ylabel('Paw r velocity') + ax31.set_title('New times') + ax32.set_xlabel('Time from first move') + + ax41.imshow(self.frame_example[0]) + rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], + self.roi[0][1] - self.roi[0][0], + linewidth=4, edgecolor='g', facecolor='none') + ax41.add_patch(rect) + + 
ax42.plot(self.all_me) + + return fig + + def plot_without_behavior(self): + + fig = plt.figure() + fig.set_size_inches(7, 7) + gs = gridspec.GridSpec(1, 2, figure=fig) + gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0]) + ax01 = fig.add_subplot(gs0[0, 0]) + ax02 = fig.add_subplot(gs0[1, 0]) + ax03 = fig.add_subplot(gs0[2, 0]) + + gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1]) + ax04 = fig.add_subplot(gs1[0, 0]) + ax05 = fig.add_subplot(gs1[1, 0]) + + ax01.plot(self.t_shifts, self.shifts, label='shifts') + ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt') + ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10) + ax01.legend() + ax01.set_ylabel('Frames') + ax01.set_xlabel('Time in session') + + xs = np.searchsorted(self.ttl_times, self.t_shifts) + ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps'] + ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl') + ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10) + ax02.legend() + ax02.set_ylabel('Frames') + ax02.set_xlabel('Time in session') + + ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], + 'k', label='extracted - new') + ax03.legend() + ax03.set_ylim(-5, 5) + ax03.set_ylabel('Frames') + ax03.set_xlabel('Time in session') + + ax04.imshow(self.frame_example[0]) + rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], + self.roi[0][1] - self.roi[0][0], + linewidth=4, edgecolor='g', facecolor='none') + ax04.add_patch(rect) + + ax05.plot(self.all_me) + + return fig + + def process(self): + + # Compute the motion energy of the wheel for the whole video + wg = WindowGenerator(self.camera_meta['length'], 5000, 4) + out = Parallel(n_jobs=self.nprocess)(delayed(self.compute_motion_energy)(first, last, wg, iw) + for iw, (first, last) in enumerate(wg.firstlast)) + # Concatenate the motion energy into one big array + self.all_me 
= np.array([]) + for vals in out[:-1]: + self.all_me = np.r_[self.all_me, vals] + + toverlap = self.twin - 1 + all_me = np.r_[np.full((int(self.camera_meta['fps'] * toverlap)), np.nan), self.all_me] + to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / self.frate)[::-1] + times = np.r_[to_app, self.times] + + wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), + int(self.camera_meta['fps'] * toverlap)) + + out = Parallel(n_jobs=self.nprocess)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) + for iw, (first, last) in enumerate(wg.firstlast)) + + self.shifts = np.array([]) + self.t_shifts = np.array([]) + for vals in out[:-1]: + self.shifts = np.r_[self.shifts, vals[0]] + self.t_shifts = np.r_[self.t_shifts, vals[1]] + + idx = np.bitwise_and(self.t_shifts >= self.ttl_times[0], self.t_shifts < self.ttl_times[-1]) + self.shifts = self.shifts[idx] + self.t_shifts = self.t_shifts[idx] + shifts_filt = ndimage.percentile_filter(self.shifts, 80, 120) + shifts_filt = self.clean_shifts(shifts_filt, n=1) + self.shifts_filt = self.clean_shifts(shifts_filt, n=2) + + self.qc, self.qc_outcome = self.qc_shifts(self.shifts, self.shifts_filt) + + self.new_times = self.extract_times(self.shifts_filt, self.t_shifts) + + if self.upload: + fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior() + save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', + f'video_wheel_alignment_{self.label}.png')) + save_fig_path.parent.mkdir(exist_ok=True, parents=True) + fig.savefig(save_fig_path) + snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one) + snp.outputs = [save_fig_path] + snp.register_images(widths=['orig']) + + return self.new_times diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index 34e668ced..5bcaf2873 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -510,7 +510,7 @@ def 
prepare_experiment(session_path, acquisition_description=None, local=None, r # won't be preserved by create_basic_transfer_params by default remote = False if remote is False else params['REMOTE_DATA_FOLDER_PATH'] - # THis is in the docstring but still, if the session Path is absolute, we need to make it relative + # This is in the docstring but still, if the session Path is absolute, we need to make it relative if Path(session_path).is_absolute(): session_path = Path(*session_path.parts[-3:]) diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py index 7cc317c28..6f1c8d506 100644 --- a/ibllib/pipes/behavior_tasks.py +++ b/ibllib/pipes/behavior_tasks.py @@ -9,14 +9,12 @@ from ibllib.oneibl.registration import get_lab from ibllib.pipes import base_tasks -from ibllib.io.raw_data_loaders import load_settings +from ibllib.io.raw_data_loaders import load_settings, load_bpod_fronts from ibllib.qc.task_extractors import TaskQCExtractor from ibllib.qc.task_metrics import HabituationQC, TaskQC from ibllib.io.extractors.ephys_passive import PassiveChoiceWorld -from ibllib.io.extractors import bpod_trials -from ibllib.io.extractors.base import get_session_extractor_type from ibllib.io.extractors.bpod_trials import get_bpod_extractor -from ibllib.io.extractors.ephys_fpga import extract_all +from ibllib.io.extractors.ephys_fpga import FpgaTrials, get_sync_and_chn_map from ibllib.io.extractors.mesoscope import TimelineTrials from ibllib.pipes import training_status from ibllib.plots.figures import BehaviourPlots @@ -73,25 +71,43 @@ def signature(self): } return signature - def _run(self, update=True): + def _run(self, update=True, save=True): """ Extracts an iblrig training session """ - extractor = bpod_trials.get_bpod_extractor(self.session_path, task_collection=self.collection) - trials, output_files = extractor.extract(task_collection=self.collection, save=True) + trials, output_files = self._extract_behaviour(save=save) if trials is None: return None 
if self.one is None or self.one.offline: return output_files + # Run the task QC + self._run_qc(trials, update=update) + return output_files + + def _extract_behaviour(self, **kwargs): + self.extractor = get_bpod_extractor(self.session_path, task_collection=self.collection) + self.extractor.default_path = self.output_collection + return self.extractor.extract(task_collection=self.collection, **kwargs) + + def _run_qc(self, trials_data=None, update=True): + if not self.extractor or trials_data is None: + trials_data, _ = self._extract_behaviour(save=False) + if not trials_data: + raise ValueError('No trials data found') + # Compile task data for QC qc = HabituationQC(self.session_path, one=self.one) - qc.extractor = TaskQCExtractor(self.session_path, sync_collection=self.sync_collection, + qc.extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, sync_type=self.sync, task_collection=self.collection) + + # Currently only the data field is accessed + qc.extractor.data = qc.extractor.rename_data(trials_data.copy()) + namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' qc.run(update=update, namespace=namespace) - return output_files + return qc class TrialRegisterRaw(base_tasks.RegisterRawDataTask, base_tasks.BehaviourTask): @@ -213,6 +229,7 @@ def _run(self, **kwargs): class ChoiceWorldTrialsBpod(base_tasks.BehaviourTask): priority = 90 job_size = 'small' + extractor = None @property def signature(self): @@ -234,38 +251,53 @@ def signature(self): } return signature - def _run(self, update=True): + def _run(self, update=True, save=True): """ Extracts an iblrig training session """ - extractor = bpod_trials.get_bpod_extractor(self.session_path, task_collection=self.collection) - extractor.default_path = self.output_collection - trials, output_files = extractor.extract(task_collection=self.collection, save=True) + trials, output_files = self._extract_behaviour(save=save) if trials 
is None: return None if self.one is None or self.one.offline: return output_files + # Run the task QC + self._run_qc(trials) + + return output_files + + def _extract_behaviour(self, **kwargs): + self.extractor = get_bpod_extractor(self.session_path, task_collection=self.collection) + self.extractor.default_path = self.output_collection + return self.extractor.extract(task_collection=self.collection, **kwargs) + + def _run_qc(self, trials_data=None, update=True): + if not self.extractor or trials_data is None: + trials_data, _ = self._extract_behaviour(save=False) + if not trials_data: + raise ValueError('No trials data found') + # Compile task data for QC - type = get_session_extractor_type(self.session_path, task_collection=self.collection) - # FIXME Task data should not need re-extracting - if type == 'habituation': - qc = HabituationQC(self.session_path, one=self.one) - qc.extractor = TaskQCExtractor(self.session_path, one=self.one, sync_collection=self.sync_collection, - sync_type=self.sync, task_collection=self.collection) - else: # Update wheel data - qc = TaskQC(self.session_path, one=self.one) - qc.extractor = TaskQCExtractor(self.session_path, one=self.one, sync_collection=self.sync_collection, - sync_type=self.sync, task_collection=self.collection) - qc.extractor.wheel_encoding = 'X1' + qc_extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, + sync_type=self.sync, task_collection=self.collection) + qc_extractor.data = qc_extractor.rename_data(trials_data) + if type(self.extractor).__name__ == 'HabituationTrials': + qc = HabituationQC(self.session_path, one=self.one, log=_logger) + else: + qc = TaskQC(self.session_path, one=self.one, log=_logger) + qc_extractor.wheel_encoding = 'X1' + qc_extractor.settings = self.extractor.settings + qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts( + self.session_path, task_collection=self.collection) + qc.extractor = qc_extractor + # Aggregate 
and update Alyx QC fields namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' qc.run(update=update, namespace=namespace) - - return output_files + return qc -class ChoiceWorldTrialsNidq(base_tasks.BehaviourTask): +class ChoiceWorldTrialsNidq(ChoiceWorldTrialsBpod): priority = 90 job_size = 'small' @@ -312,21 +344,41 @@ def _behaviour_criterion(self, update=True): "sessions", eid, "extended_qc", {"behavior": int(good_enough)} ) - def _extract_behaviour(self): - dsets, out_files = extract_all(self.session_path, self.sync_collection, task_collection=self.collection, - save_path=self.session_path.joinpath(self.output_collection), - protocol_number=self.protocol_number, save=True) + def _extract_behaviour(self, save=True, **kwargs): + # Extract Bpod trials + bpod_trials, _ = super()._extract_behaviour(save=False, **kwargs) - return dsets, out_files + # Sync Bpod trials to FPGA + sync, chmap = get_sync_and_chn_map(self.session_path, self.sync_collection) + self.extractor = FpgaTrials(self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor) + outputs, files = self.extractor.extract( + save=save, sync=sync, chmap=chmap, path_out=self.session_path.joinpath(self.output_collection), + task_collection=self.collection, protocol_number=self.protocol_number, **kwargs) + return outputs, files - def _run_qc(self, trials_data, update=True, plot_qc=True): - # Run the task QC - qc = TaskQC(self.session_path, one=self.one, log=_logger) - qc.extractor = TaskQCExtractor(self.session_path, lazy=True, one=qc.one, sync_collection=self.sync_collection, + def _run_qc(self, trials_data=None, update=False, plot_qc=False): + if not self.extractor or trials_data is None: + trials_data, _ = self._extract_behaviour(save=False) + if not trials_data: + raise ValueError('No trials data found') + + # Compile task data for QC + qc_extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, 
sync_type=self.sync, task_collection=self.collection) - # Extract extra datasets required for QC - qc.extractor.data = trials_data # FIXME This line is pointless - qc.extractor.extract_data() + qc_extractor.data = qc_extractor.rename_data(trials_data.copy()) + if type(self.extractor).__name__ == 'HabituationTrials': + qc = HabituationQC(self.session_path, one=self.one, log=_logger) + else: + qc = TaskQC(self.session_path, one=self.one, log=_logger) + qc_extractor.settings = self.extractor.settings + # Add Bpod wheel data + wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps']) + qc_extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod + qc_extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] + qc_extractor.wheel_encoding = 'X4' + qc_extractor.frame_ttls = self.extractor.frame2ttl + qc_extractor.audio_ttls = self.extractor.audio + qc.extractor = qc_extractor # Aggregate and update Alyx QC fields namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' @@ -345,9 +397,10 @@ def _run_qc(self, trials_data, update=True, plot_qc=True): _logger.error('Could not create Trials QC Plot') _logger.error(traceback.format_exc()) self.status = -1 + return qc - def _run(self, update=True, plot_qc=True): - dsets, out_files = self._extract_behaviour() + def _run(self, update=True, plot_qc=True, save=True): + dsets, out_files = self._extract_behaviour(save=save) if not self.one or self.one.offline: return out_files @@ -378,63 +431,24 @@ def signature(self): for fn in filter(None, extractor.save_names)] return signature - def _extract_behaviour(self): + def _extract_behaviour(self, save=True, **kwargs): """Extract the Bpod trials data and Timeline acquired signals.""" # First determine the extractor from the task protocol - extractor = get_bpod_extractor(self.session_path, self.protocol, self.collection) - ret, _ = extractor.extract(save=False, task_collection=self.collection) - 
bpod_trials = {k: v for k, v in zip(extractor.var_names, ret)} + bpod_trials, _ = ChoiceWorldTrialsBpod._extract_behaviour(self, save=False, **kwargs) - trials = TimelineTrials(self.session_path, bpod_trials=bpod_trials) + # Sync Bpod trials to DAQ + self.extractor = TimelineTrials(self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor) save_path = self.session_path / self.output_collection - if not self._spacer_support(extractor.settings): + if not self._spacer_support(self.extractor.settings): _logger.warning('Protocol spacers not supported; setting protocol_number to None') self.protocol_number = None - dsets, out_files = trials.extract( - save=True, path_out=save_path, sync_collection=self.sync_collection, - task_collection=self.collection, protocol_number=self.protocol_number) - if not isinstance(dsets, dict): - dsets = {k: v for k, v in zip(trials.var_names, dsets)} - - self.timeline = trials.timeline # Store for QC later - self.frame2ttl = trials.frame2ttl - self.audio = trials.audio + dsets, out_files = self.extractor.extract( + save=save, path_out=save_path, sync_collection=self.sync_collection, + task_collection=self.collection, protocol_number=self.protocol_number, **kwargs) return dsets, out_files - def _run_qc(self, trials_data, update=True, **kwargs): - """ - Run the task QC and update Alyx with results. - - Parameters - ---------- - trials_data : dict - The extracted trials data. - update : bool - If true, update Alyx with the result. - - Notes - ----- - - Unlike the super class, currently the QC plots are not generated. - - Expects the frame2ttl and audio attributes to be set from running _extract_behaviour. 
- """ - # TODO Task QC extractor for Timeline - qc = TaskQC(self.session_path, one=self.one, log=_logger) - qc.extractor = TaskQCExtractor(self.session_path, lazy=True, one=qc.one, sync_collection=self.sync_collection, - sync_type=self.sync, task_collection=self.collection) - # Extract extra datasets required for QC - qc.extractor.data = TaskQCExtractor.rename_data(trials_data.copy()) - qc.extractor.load_raw_data() - - qc.extractor.frame_ttls = self.frame2ttl - qc.extractor.audio_ttls = self.audio - # qc.extractor.bpod_ttls = channel_events('bpod') - - # Aggregate and update Alyx QC fields - namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' - qc.run(update=update, namespace=namespace) - class TrainingStatus(base_tasks.BehaviourTask): priority = 90 diff --git a/ibllib/pipes/local_server.py b/ibllib/pipes/local_server.py index 895b0f20b..47f6322b5 100644 --- a/ibllib/pipes/local_server.py +++ b/ibllib/pipes/local_server.py @@ -10,7 +10,7 @@ from one.api import ONE from one.webclient import AlyxClient -from one.remote.globus import get_lab_from_endpoint_id +from one.remote.globus import get_lab_from_endpoint_id, get_local_endpoint_id from iblutil.util import setup_logger from ibllib.io.extractors.base import get_pipeline, get_task_protocol, get_session_extractor_type @@ -74,9 +74,10 @@ def report_health(one): status.update(_get_volume_usage('/mnt/s0/Data', 'raid')) status.update(_get_volume_usage('/', 'system')) - lab_names = get_lab_from_endpoint_id(alyx=one.alyx) - for ln in lab_names: - one.alyx.json_field_update(endpoint='labs', uuid=ln, field_name='json', data=status) + data_repos = one.alyx.rest('data-repository', 'list', globus_endpoint_id=get_local_endpoint_id()) + + for dr in data_repos: + one.alyx.json_field_update(endpoint='data-repository', uuid=dr['name'], field_name='json', data=status) def job_creator(root_path, one=None, dry=False, rerun=False, max_md5_size=None): diff --git a/ibllib/plots/misc.py 
b/ibllib/plots/misc.py index 133eb12e8..36cd56afb 100644 --- a/ibllib/plots/misc.py +++ b/ibllib/plots/misc.py @@ -74,13 +74,19 @@ def insert_zeros(trace): class Density: - def __init__(self, w, fs=1, cmap='Greys_r', ax=None, taxis=0, title=None, **kwargs): + def __init__(self, w, fs=30_000, cmap='Greys_r', ax=None, taxis=0, title=None, gain=None, **kwargs): """ - Matplotlib display of traces as a density display + Matplotlib display of traces as a density display using `imshow()`. :param w: 2D array (numpy array dimension nsamples, ntraces) - :param fs: sampling frequency (Hz) - :param ax: axis to plot in + :param fs: sampling frequency (Hz). [default: 30000] + :param cmap: Name of MPL colormap to use in `imshow()`. [default: 'Greys_r'] + :param ax: Axis to plot in. If `None`, a new one is created. [default: `None`] + :param taxis: Time axis of input array (w). [default: 0] + :param title: Title to display on plot. [default: `None`] + :param gain: Gain in dB to display. Note: overrides `vmin` and `vmax` kwargs to `imshow()`. 
+ Default: [`None` (auto)] + :param kwargs: Key word arguments passed to `imshow()` :return: None """ w = w.reshape(w.shape[0], -1) @@ -98,6 +104,9 @@ def __init__(self, w, fs=1, cmap='Greys_r', ax=None, taxis=0, title=None, **kwar self.figure, ax = plt.subplots() else: self.figure = ax.get_figure() + if gain: + kwargs["vmin"] = - 4 * (10 ** (gain / 20)) + kwargs["vmax"] = -kwargs["vmin"] self.im = ax.imshow(w, aspect='auto', cmap=cmap, extent=extent, origin=origin, **kwargs) ax.set_ylabel(ylabel) ax.set_xlabel(xlabel) diff --git a/ibllib/qc/task_extractors.py b/ibllib/qc/task_extractors.py index f0d46ed02..5f5269710 100644 --- a/ibllib/qc/task_extractors.py +++ b/ibllib/qc/task_extractors.py @@ -1,4 +1,5 @@ import logging +import warnings import numpy as np from scipy.interpolate import interp1d @@ -26,16 +27,16 @@ 'wheel_position', 'wheel_timestamps'] -class TaskQCExtractor(object): +class TaskQCExtractor: def __init__(self, session_path, lazy=False, one=None, download_data=False, bpod_only=False, sync_collection=None, sync_type=None, task_collection=None): """ - A class for extracting the task data required to perform task quality control + A class for extracting the task data required to perform task quality control. :param session_path: a valid session path :param lazy: if True, the data are not extracted immediately :param one: an instance of ONE, used to download the raw data if download_data is True :param download_data: if True, any missing raw data is downloaded via ONE - :param bpod_only: extract from from raw Bpod data only, even for FPGA sessions + :param bpod_only: extract from raw Bpod data only, even for FPGA sessions """ if not is_session_path(session_path): raise ValueError('Invalid session path') @@ -151,6 +152,8 @@ def extract_data(self): intervals_bpod to be assigned to the data attribute before calling this function. 
:return: """ + warnings.warn('The TaskQCExtractor.extract_data will be removed in the future, ' + 'use dynamic pipeline behaviour tasks instead.', DeprecationWarning) self.log.info(f'Extracting session: {self.session_path}') self.type = self.type or get_session_extractor_type(self.session_path, task_collection=self.task_collection) # Finds the sync type when it isn't explicitly set, if ephys we assume nidq otherwise bpod diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 36f2b4806..42361645d 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -69,21 +69,21 @@ class TaskQC(base.QC): """A class for computing task QC metrics""" criteria = dict() - criteria['default'] = {"PASS": 0.99, "WARNING": 0.90, "FAIL": 0} # Note: WARNING was 0.95 prior to Aug 2022 - criteria['_task_stimOff_itiIn_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_positive_feedback_stimOff_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_negative_feedback_stimOff_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_wheel_move_during_closed_loop'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_response_stimFreeze_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_detected_wheel_moves'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_trial_length'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_goCue_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_errorCue_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_stimOn_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_stimOff_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_stimFreeze_delays'] = {"PASS": 0.99, "WARNING": 0} - criteria['_task_iti_delays'] = {"NOT_SET": 0} - criteria['_task_passed_trial_checks'] = {"NOT_SET": 0} + criteria['default'] = {'PASS': 0.99, 'WARNING': 0.90, 'FAIL': 0} # Note: WARNING was 0.95 prior to Aug 2022 + criteria['_task_stimOff_itiIn_delays'] = {'PASS': 0.99, 'WARNING': 0} + 
criteria['_task_positive_feedback_stimOff_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_negative_feedback_stimOff_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_wheel_move_during_closed_loop'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_response_stimFreeze_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_detected_wheel_moves'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_trial_length'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_goCue_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_errorCue_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_stimOn_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_stimOff_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_stimFreeze_delays'] = {'PASS': 0.99, 'WARNING': 0} + criteria['_task_iti_delays'] = {'NOT_SET': 0} + criteria['_task_passed_trial_checks'] = {'NOT_SET': 0} @staticmethod def _thresholding(qc_value, thresholds=None): @@ -100,7 +100,7 @@ def _thresholding(qc_value, thresholds=None): if qc_value is None or np.isnan(qc_value): return int(-1) elif (qc_value > MAX_BOUND) or (qc_value < MIN_BOUND): - raise ValueError("Values out of bound") + raise ValueError('Values out of bound') if 'PASS' in thresholds.keys() and qc_value >= thresholds['PASS']: return 0 if 'WARNING' in thresholds.keys() and qc_value >= thresholds['WARNING']: @@ -151,7 +151,7 @@ def compute(self, **kwargs): if self.extractor is None: kwargs['download_data'] = kwargs.pop('download_data', self.download_data) self.load_data(**kwargs) - self.log.info(f"Session {self.session_path}: Running QC on behavior data...") + self.log.info(f'Session {self.session_path}: Running QC on behavior data...') self.metrics, self.passed = get_bpodqc_metrics_frame( self.extractor.data, wheel_gain=self.extractor.settings['STIM_GAIN'], # The wheel gain @@ -229,7 +229,7 @@ def compute(self, download_data=None): # If download_data is None, decide based on whether eid or session path was provided 
ensure_data = self.download_data if download_data is None else download_data self.load_data(download_data=ensure_data) - self.log.info(f"Session {self.session_path}: Running QC on habituation data...") + self.log.info(f'Session {self.session_path}: Running QC on habituation data...') # Initialize checks prefix = '_task_' @@ -274,16 +274,16 @@ def compute(self, download_data=None): # Check event orders: trial_start < stim on < stim center < feedback < stim off check = prefix + 'trial_event_sequence' nans = ( - np.isnan(data["intervals"][:, 0]) | # noqa - np.isnan(data["stimOn_times"]) | # noqa - np.isnan(data["stimCenter_times"]) | - np.isnan(data["valveOpen_times"]) | # noqa - np.isnan(data["stimOff_times"]) + np.isnan(data['intervals'][:, 0]) | # noqa + np.isnan(data['stimOn_times']) | # noqa + np.isnan(data['stimCenter_times']) | + np.isnan(data['valveOpen_times']) | # noqa + np.isnan(data['stimOff_times']) ) - a = np.less(data["intervals"][:, 0], data["stimOn_times"], where=~nans) - b = np.less(data["stimOn_times"], data["stimCenter_times"], where=~nans) - c = np.less(data["stimCenter_times"], data["valveOpen_times"], where=~nans) - d = np.less(data["valveOpen_times"], data["stimOff_times"], where=~nans) + a = np.less(data['intervals'][:, 0], data['stimOn_times'], where=~nans) + b = np.less(data['stimOn_times'], data['stimCenter_times'], where=~nans) + c = np.less(data['stimCenter_times'], data['valveOpen_times'], where=~nans) + d = np.less(data['valveOpen_times'], data['stimOff_times'], where=~nans) metrics[check] = a & b & c & d & ~nans passed[check] = metrics[check].astype(float) @@ -291,7 +291,7 @@ def compute(self, download_data=None): # Check that the time difference between the visual stimulus center-command being # triggered and the stimulus effectively appearing in the center is smaller than 150 ms. 
check = prefix + 'stimCenter_delays' - metric = np.nan_to_num(data["stimCenter_times"] - data["stimCenterTrigger_times"], + metric = np.nan_to_num(data['stimCenter_times'] - data['stimCenterTrigger_times'], nan=np.inf) passed[check] = (metric <= 0.15) & (metric > 0) metrics[check] = metric @@ -375,9 +375,9 @@ def check_stimOn_goCue_delays(data, **_): """ # Calculate the difference between stimOn and goCue times. # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["goCue_times"] - data["stimOn_times"], nan=np.inf) + metric = np.nan_to_num(data['goCue_times'] - data['stimOn_times'], nan=np.inf) passed = (metric < 0.01) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -391,9 +391,9 @@ def check_response_feedback_delays(data, **_): :param data: dict of trial data with keys ('feedback_times', 'response_times', 'intervals') """ - metric = np.nan_to_num(data["feedback_times"] - data["response_times"], nan=np.inf) + metric = np.nan_to_num(data['feedback_times'] - data['response_times'], nan=np.inf) passed = (metric < 0.01) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -410,13 +410,13 @@ def check_response_stimFreeze_delays(data, **_): """ # Calculate the difference between stimOn and goCue times. # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. 
- metric = np.nan_to_num(data["stimFreeze_times"] - data["response_times"], nan=np.inf) + metric = np.nan_to_num(data['stimFreeze_times'] - data['response_times'], nan=np.inf) # Test for valid values passed = ((metric < 0.1) & (metric > 0)).astype(float) # Finally remove no_go trials (stimFreeze triggered differently in no_go trials) # These values are ignored in calculation of proportion passed - passed[data["choice"] == 0] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['choice'] == 0] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -431,12 +431,12 @@ def check_stimOff_itiIn_delays(data, **_): 'choice') """ # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["itiIn_times"] - data["stimOff_times"], nan=np.inf) + metric = np.nan_to_num(data['itiIn_times'] - data['stimOff_times'], nan=np.inf) passed = ((metric < 0.01) & (metric >= 0)).astype(float) # Remove no_go trials (stimOff triggered differently in no_go trials) # NaN values are ignored in calculation of proportion passed - metric[data["choice"] == 0] = passed[data["choice"] == 0] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[data['choice'] == 0] = passed[data['choice'] == 0] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -451,14 +451,14 @@ def check_iti_delays(data, **_): :param data: dict of trial data with keys ('stimOff_times', 'intervals') """ # Initialize array the length of completed trials - metric = np.full(data["intervals"].shape[0], np.nan) + metric = np.full(data['intervals'].shape[0], np.nan) passed = metric.copy() # Get the difference between stim off and the start of the next trial # Missing data are set to Inf, except for the last trial which is a NaN metric[:-1] = \ - np.nan_to_num(data["intervals"][1:, 0] - data["stimOff_times"][:-1] - 0.5, 
nan=np.inf) + np.nan_to_num(data['intervals'][1:, 0] - data['stimOff_times'][:-1] - 0.5, nan=np.inf) passed[:-1] = np.abs(metric[:-1]) < .5 # Last trial is not counted - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -474,11 +474,11 @@ def check_positive_feedback_stimOff_delays(data, **_): 'correct') """ # If either are NaN, the result will be Inf to ensure that it crosses the failure threshold. - metric = np.nan_to_num(data["stimOff_times"] - data["feedback_times"] - 1, nan=np.inf) + metric = np.nan_to_num(data['stimOff_times'] - data['feedback_times'] - 1, nan=np.inf) passed = (np.abs(metric) < 0.15).astype(float) # NaN values are ignored in calculation of proportion passed; ignore incorrect trials here - metric[~data["correct"]] = passed[~data["correct"]] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[~data['correct']] = passed[~data['correct']] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -492,12 +492,12 @@ def check_negative_feedback_stimOff_delays(data, **_): :param data: dict of trial data with keys ('stimOff_times', 'errorCue_times', 'intervals') """ - metric = np.nan_to_num(data["stimOff_times"] - data["errorCue_times"] - 2, nan=np.inf) + metric = np.nan_to_num(data['stimOff_times'] - data['errorCue_times'] - 2, nan=np.inf) # Apply criteria passed = (np.abs(metric) < 0.15).astype(float) # Remove none negative feedback trials - metric[data["correct"]] = passed[data["correct"]] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[data['correct']] = passed[data['correct']] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -515,12 +515,12 @@ def check_wheel_move_before_feedback(data, **_): """ # Get tuple of wheel times and positions within 100ms of feedback traces = traces_by_trial( 
- data["wheel_timestamps"], - data["wheel_position"], - start=data["feedback_times"] - 0.05, - end=data["feedback_times"] + 0.05, + data['wheel_timestamps'], + data['wheel_position'], + start=data['feedback_times'] - 0.05, + end=data['feedback_times'] + 0.05, ) - metric = np.zeros_like(data["feedback_times"]) + metric = np.zeros_like(data['feedback_times']) # For each trial find the displacement for i, trial in enumerate(traces): pos = trial[1] @@ -528,12 +528,12 @@ def check_wheel_move_before_feedback(data, **_): metric[i] = pos[-1] - pos[0] # except no-go trials - metric[data["choice"] == 0] = np.nan # NaN = trial ignored for this check + metric[data['choice'] == 0] = np.nan # NaN = trial ignored for this check nans = np.isnan(metric) passed = np.zeros_like(metric) * np.nan passed[~nans] = (metric[~nans] != 0).astype(float) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -555,15 +555,15 @@ def _wheel_move_during_closed_loop(re_ts, re_pos, data, wheel_gain=None, tol=1, :param tol: the criterion in visual degrees """ if wheel_gain is None: - _log.warning("No wheel_gain input in function call, returning None") + _log.warning('No wheel_gain input in function call, returning None') return None, None # Get tuple of wheel times and positions over each trial's closed-loop period traces = traces_by_trial(re_ts, re_pos, - start=data["goCueTrigger_times"], - end=data["response_times"]) + start=data['goCueTrigger_times'], + end=data['response_times']) - metric = np.zeros_like(data["feedback_times"]) + metric = np.zeros_like(data['feedback_times']) # For each trial find the absolute displacement for i, trial in enumerate(traces): t, pos = trial @@ -574,16 +574,16 @@ def _wheel_move_during_closed_loop(re_ts, re_pos, data, wheel_gain=None, tol=1, metric[i] = np.abs(pos - origin).max() # Load wheel_gain and thresholds for each trial - wheel_gain = np.array([wheel_gain] * 
len(data["position"])) - thresh = data["position"] + wheel_gain = np.array([wheel_gain] * len(data['position'])) + thresh = data['position'] # abs displacement, s, in mm required to move 35 visual degrees s_mm = np.abs(thresh / wheel_gain) # don't care about direction criterion = cm_to_rad(s_mm * 1e-1) # convert abs displacement to radians (wheel pos is in rad) metric = metric - criterion # difference should be close to 0 rad_per_deg = cm_to_rad(1 / wheel_gain * 1e-1) passed = (np.abs(metric) < rad_per_deg * tol).astype(float) # less than 1 visual degree off - metric[data["choice"] == 0] = passed[data["choice"] == 0] = np.nan # except no-go trials - assert data["intervals"].shape[0] == len(metric) == len(passed) + metric[data['choice'] == 0] = passed[data['choice'] == 0] = np.nan # except no-go trials + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -642,25 +642,25 @@ def check_wheel_freeze_during_quiescence(data, **_): :param data: dict of trial data with keys ('wheel_timestamps', 'wheel_position', 'quiescence', 'intervals', 'stimOnTrigger_times') """ - assert np.all(np.diff(data["wheel_timestamps"]) >= 0) - assert data["quiescence"].size == data["stimOnTrigger_times"].size + assert np.all(np.diff(data['wheel_timestamps']) >= 0) + assert data['quiescence'].size == data['stimOnTrigger_times'].size # Get tuple of wheel times and positions over each trial's quiescence period - qevt_start_times = data["stimOnTrigger_times"] - data["quiescence"] + qevt_start_times = data['stimOnTrigger_times'] - data['quiescence'] traces = traces_by_trial( - data["wheel_timestamps"], - data["wheel_position"], + data['wheel_timestamps'], + data['wheel_position'], start=qevt_start_times, - end=data["stimOnTrigger_times"] + end=data['stimOnTrigger_times'] ) - metric = np.zeros((len(data["quiescence"]), 2)) # (n_trials, n_directions) + metric = np.zeros((len(data['quiescence']), 2)) # (n_trials, n_directions) for i, trial in enumerate(traces): t, 
pos = trial # Get the last position before the period began if pos.size > 0: # Find the position of the preceding sample and subtract it - idx = np.abs(data["wheel_timestamps"] - t[0]).argmin() - 1 - origin = data["wheel_position"][idx if idx != -1 else 0] + idx = np.abs(data['wheel_timestamps'] - t[0]).argmin() - 1 + origin = data['wheel_position'][idx if idx != -1 else 0] # Find the absolute min and max relative to the last sample metric[i, :] = np.abs([np.min(pos - origin), np.max(pos - origin)]) # Reduce to the largest displacement found in any direction @@ -668,7 +668,7 @@ def check_wheel_freeze_during_quiescence(data, **_): metric = 180 * metric / np.pi # convert to degrees from radians criterion = 2 # Position shouldn't change more than 2 in either direction passed = metric < criterion - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -685,8 +685,8 @@ def check_detected_wheel_moves(data, min_qt=0, **_): """ # Depending on task version this may be a single value or an array of quiescent periods min_qt = np.array(min_qt) - if min_qt.size > data["intervals"].shape[0]: - min_qt = min_qt[:data["intervals"].shape[0]] + if min_qt.size > data['intervals'].shape[0]: + min_qt = min_qt[:data['intervals'].shape[0]] metric = data['firstMovement_times'] qevt_start = data['goCueTrigger_times'] - np.array(min_qt) @@ -714,25 +714,25 @@ def check_error_trial_event_sequence(data, **_): """ # An array the length of N trials where True means at least one event time was NaN (bad) nans = ( - np.isnan(data["intervals"][:, 0]) | - np.isnan(data["goCue_times"]) | # noqa - np.isnan(data["errorCue_times"]) | # noqa - np.isnan(data["itiIn_times"]) | # noqa - np.isnan(data["intervals"][:, 1]) + np.isnan(data['intervals'][:, 0]) | + np.isnan(data['goCue_times']) | # noqa + np.isnan(data['errorCue_times']) | # noqa + np.isnan(data['itiIn_times']) | # noqa + 
np.isnan(data['intervals'][:, 1]) ) # For each trial check that the events happened in the correct order (ignore NaN values) - a = np.less(data["intervals"][:, 0], data["goCue_times"], where=~nans) # Start time < go cue - b = np.less(data["goCue_times"], data["errorCue_times"], where=~nans) # Go cue < error cue - c = np.less(data["errorCue_times"], data["itiIn_times"], where=~nans) # Error cue < ITI start - d = np.less(data["itiIn_times"], data["intervals"][:, 1], where=~nans) # ITI start < end time + a = np.less(data['intervals'][:, 0], data['goCue_times'], where=~nans) # Start time < go cue + b = np.less(data['goCue_times'], data['errorCue_times'], where=~nans) # Go cue < error cue + c = np.less(data['errorCue_times'], data['itiIn_times'], where=~nans) # Error cue < ITI start + d = np.less(data['itiIn_times'], data['intervals'][:, 1], where=~nans) # ITI start < end time # For each trial check all events were in order AND all event times were not NaN metric = a & b & c & d & ~nans passed = metric.astype(float) - passed[data["correct"]] = np.nan # Look only at incorrect trials - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['correct']] = np.nan # Look only at incorrect trials + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -749,25 +749,25 @@ def check_correct_trial_event_sequence(data, **_): """ # An array the length of N trials where True means at least one event time was NaN (bad) nans = ( - np.isnan(data["intervals"][:, 0]) | - np.isnan(data["goCue_times"]) | # noqa - np.isnan(data["valveOpen_times"]) | - np.isnan(data["itiIn_times"]) | # noqa - np.isnan(data["intervals"][:, 1]) + np.isnan(data['intervals'][:, 0]) | + np.isnan(data['goCue_times']) | # noqa + np.isnan(data['valveOpen_times']) | + np.isnan(data['itiIn_times']) | # noqa + np.isnan(data['intervals'][:, 1]) ) # For each trial check that the events happened in the correct order (ignore NaN values) - a = 
np.less(data["intervals"][:, 0], data["goCue_times"], where=~nans) # Start time < go cue - b = np.less(data["goCue_times"], data["valveOpen_times"], where=~nans) # Go cue < feedback - c = np.less(data["valveOpen_times"], data["itiIn_times"], where=~nans) # Feedback < ITI start - d = np.less(data["itiIn_times"], data["intervals"][:, 1], where=~nans) # ITI start < end time + a = np.less(data['intervals'][:, 0], data['goCue_times'], where=~nans) # Start time < go cue + b = np.less(data['goCue_times'], data['valveOpen_times'], where=~nans) # Go cue < feedback + c = np.less(data['valveOpen_times'], data['itiIn_times'], where=~nans) # Feedback < ITI start + d = np.less(data['itiIn_times'], data['intervals'][:, 1], where=~nans) # ITI start < end time # For each trial True means all events were in order AND all event times were not NaN metric = a & b & c & d & ~nans passed = metric.astype(float) - passed[~data["correct"]] = np.nan # Look only at correct trials - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[~data['correct']] = np.nan # Look only at correct trials + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -799,7 +799,7 @@ def check_n_trial_events(data, **_): 'wheel_moves_peak_amplitude', 'wheel_moves_intervals', 'wheel_timestamps', 'wheel_intervals', 'stimFreeze_times'] events = [k for k in data.keys() if k.endswith('_times') and k not in exclude] - metric = np.zeros(data["intervals"].shape[0], dtype=bool) + metric = np.zeros(data['intervals'].shape[0], dtype=bool) # For each trial interval check that one of each trial event occurred. For incorrect trials, # check the error cue trigger occurred within the interval, otherwise check it is nan. 
@@ -822,9 +822,9 @@ def check_trial_length(data, **_): :param data: dict of trial data with keys ('feedback_times', 'goCue_times', 'intervals') """ # NaN values are usually ignored so replace them with Inf so they fail the threshold - metric = np.nan_to_num(data["feedback_times"] - data["goCue_times"], nan=np.inf) + metric = np.nan_to_num(data['feedback_times'] - data['goCue_times'], nan=np.inf) passed = (metric < 60.1) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -835,14 +835,14 @@ def check_goCue_delays(data, **_): effectively played is smaller than 1ms. Metric: M = goCue_times - goCueTrigger_times - Criterion: 0 < M <= 0.001 s + Criterion: 0 < M <= 0.0015 s Units: seconds [s] :param data: dict of trial data with keys ('goCue_times', 'goCueTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["goCue_times"] - data["goCueTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['goCue_times'] - data['goCueTrigger_times'], nan=np.inf) passed = (metric <= 0.0015) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -850,16 +850,16 @@ def check_errorCue_delays(data, **_): """ Check that the time difference between the error sound being triggered and effectively played is smaller than 1ms. 
Metric: M = errorCue_times - errorCueTrigger_times - Criterion: 0 < M <= 0.001 s + Criterion: 0 < M <= 0.0015 s Units: seconds [s] :param data: dict of trial data with keys ('errorCue_times', 'errorCueTrigger_times', 'intervals', 'correct') """ - metric = np.nan_to_num(data["errorCue_times"] - data["errorCueTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['errorCue_times'] - data['errorCueTrigger_times'], nan=np.inf) passed = ((metric <= 0.0015) & (metric > 0)).astype(float) - passed[data["correct"]] = metric[data["correct"]] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['correct']] = metric[data['correct']] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -868,15 +868,15 @@ def check_stimOn_delays(data, **_): and the stimulus effectively appearing on the screen is smaller than 150 ms. Metric: M = stimOn_times - stimOnTrigger_times - Criterion: 0 < M < 0.150 s + Criterion: 0 < M <= 0.15 s Units: seconds [s] :param data: dict of trial data with keys ('stimOn_times', 'stimOnTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["stimOn_times"] - data["stimOnTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['stimOn_times'] - data['stimOnTrigger_times'], nan=np.inf) passed = (metric <= 0.15) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -886,15 +886,15 @@ def check_stimOff_delays(data, **_): is smaller than 150 ms. 
Metric: M = stimOff_times - stimOffTrigger_times - Criterion: 0 < M < 0.150 s + Criterion: 0 < M <= 0.15 s Units: seconds [s] :param data: dict of trial data with keys ('stimOff_times', 'stimOffTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["stimOff_times"] - data["stimOffTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['stimOff_times'] - data['stimOffTrigger_times'], nan=np.inf) passed = (metric <= 0.15) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -904,15 +904,15 @@ def check_stimFreeze_delays(data, **_): is smaller than 150 ms. Metric: M = stimFreeze_times - stimFreezeTrigger_times - Criterion: 0 < M < 0.150 s + Criterion: 0 < M <= 0.15 s Units: seconds [s] :param data: dict of trial data with keys ('stimFreeze_times', 'stimFreezeTrigger_times', 'intervals') """ - metric = np.nan_to_num(data["stimFreeze_times"] - data["stimFreezeTrigger_times"], nan=np.inf) + metric = np.nan_to_num(data['stimFreeze_times'] - data['stimFreezeTrigger_times'], nan=np.inf) passed = (metric <= 0.15) & (metric > 0) - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -934,7 +934,7 @@ def check_reward_volumes(data, **_): passed[correct] = (1.5 <= metric[correct]) & (metric[correct] <= 3.) # Check incorrect trials are 0 passed[~correct] = metric[~correct] == 0 - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -946,7 +946,7 @@ def check_reward_volume_set(data, **_): :param data: dict of trial data with keys ('rewardVolume') """ - metric = data["rewardVolume"] + metric = data['rewardVolume'] passed = 0 < len(set(metric)) <= 2 and 0. 
in metric return metric, passed @@ -994,19 +994,19 @@ def check_stimulus_move_before_goCue(data, photodiode=None, **_): :param photodiode: the fronts from Bpod's BNC1 input or FPGA frame2ttl channel """ if photodiode is None: - _log.warning("No photodiode TTL input in function call, returning None") + _log.warning('No photodiode TTL input in function call, returning None') return None photodiode_clean = ephys_fpga._clean_frame2ttl(photodiode) - s = photodiode_clean["times"] + s = photodiode_clean['times'] s = s[~np.isnan(s)] # Remove NaNs metric = np.array([]) - for i, c in zip(data["intervals"][:, 0], data["goCue_times"]): + for i, c in zip(data['intervals'][:, 0], data['goCue_times']): metric = np.append(metric, np.count_nonzero(s[s > i] < (c - 0.02))) passed = (metric == 0).astype(float) # Remove no go trials - passed[data["choice"] == 0] = np.nan - assert data["intervals"].shape[0] == len(metric) == len(passed) + passed[data['choice'] == 0] = np.nan + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed @@ -1022,12 +1022,12 @@ def check_audio_pre_trial(data, audio=None, **_): :param audio: the fronts from Bpod's BNC2 input FPGA audio sync channel """ if audio is None: - _log.warning("No BNC2 input in function call, retuning None") + _log.warning('No BNC2 input in function call, returning None') return None - s = audio["times"][~np.isnan(audio["times"])] # Audio TTLs with NaNs removed + s = audio['times'][~np.isnan(audio['times'])] # Audio TTLs with NaNs removed metric = np.array([], dtype=np.int8) - for i, c in zip(data["intervals"][:, 0], data["goCue_times"]): + for i, c in zip(data['intervals'][:, 0], data['goCue_times']): metric = np.append(metric, sum(s[s > i] < (c - 0.02))) passed = metric == 0 - assert data["intervals"].shape[0] == len(metric) == len(passed) + assert data['intervals'].shape[0] == len(metric) == len(passed) return metric, passed