diff --git a/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/PulsePal.py b/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/PulsePal.py new file mode 100644 index 0000000..020f574 --- /dev/null +++ b/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/PulsePal.py @@ -0,0 +1,137 @@ +import logging +import sys +from typing import Literal +from abc import ABC, abstractmethod +import numpy as np + +from iblrig.base_choice_world import SOFTCODE +from pybpodapi.protocol import StateMachine, Bpod +from pypulsepal import PulsePalObject +from iblrig.base_tasks import BaseSession + +log = logging.getLogger('iblrig.task') + +SOFTCODE_FIRE_PULSEPAL = max(SOFTCODE).value + 1 +SOFTCODE_STOP_PULSEPAL = max(SOFTCODE).value + 2 +V_MAX = 5 + + +class PulsePalStateMachine(StateMachine): + """ + This class adds: + 1. Hardware or sofware triggering of optogenetic stimulation via a PulsePal (or BPod Analog Output Module) + EITHER + - adds soft-codes for starting and stopping the opto stim + OR + - sets up a TTL to hardware trigger the PulsePal + 2. (not yet implemented!!!) sets up a TTL channel for recording opto stim times from the PulsePal + """ + # TODO: define the TTL channel for recording opto stim times? + def __init__( + self, + bpod, + trigger_type: Literal['soft', 'hardware'] = 'soft', + is_opto_stimulation=False, + states_opto_ttls=None, + states_opto_stop=None, + opto_t_max_seconds=None, + ): + super().__init__(bpod) + self.trigger_type = trigger_type + self.is_opto_stimulation = is_opto_stimulation + self.states_opto_ttls = states_opto_ttls or [] + self.states_opto_stop = states_opto_stop or [] + + # Set global timer 1 for T_MAX + self.set_global_timer(timer_id=1, timer_duration=opto_t_max_seconds) + + def add_state(self, **kwargs): + if self.is_opto_stimulation: + if kwargs['state_name'] in self.states_opto_ttls: + if self.trigger_type == 'soft': + kwargs['output_actions'] += [('SoftCode', SOFTCODE_FIRE_PULSEPAL),] + elif self.trigger_type == 'hardware': + kwargs['output_actions'] += [('BNC2', 255),] + kwargs['output_actions'] += [(Bpod.OutputChannels.GlobalTimerTrig, 1)] # start the global timer when the opto stim comes on + elif kwargs['state_name'] in self.states_opto_stop: + if self.trigger_type == 'soft': + kwargs['output_actions'] += [('SoftCode', SOFTCODE_STOP_PULSEPAL),] + elif self.trigger_type == 'hardware': + kwargs['output_actions'] += [('BNC2', 0),] + + super().add_state(**kwargs) + +class PulsePalMixin(ABC): + """ + A mixin class that adds optogenetic stimulation capabilities to a task via the + PulsePal module (or a Analog Output module running PulsePal firmware). It is used + in conjunction with the PulsePalStateMachine class rather than the StateMachine class. + + The user must define the arm_opto_stim method to define the parameters for optogenetic stimulation. + PulsePalMixin supports soft-code triggering via the start_opto_stim and stop_opto_stim methods. + Hardware triggering is also supported by defining trigger channels in the arm_opto_stim method. + + The opto stim is currently hard-coded on output channel 1. + A TTL pulse is hard-coded on output channel 2 for accurately recording trigger times. This TTL + will rise when the opto stim starts and fall when it stops, thus accurately recording software trigger times. + """ + + def start_opto_hardware(self): + self.pulsepal_connection = PulsePalObject('COM13') # TODO: get port from hardware params + log.warning('Connected to PulsePal') + # TODO: get the calibration value for this specific cannula + #super().start_hardware() # TODO: move this out + + # add the softcodes for the PulsePal + soft_code_dict = self.bpod.softcodes + soft_code_dict.update({SOFTCODE_STOP_PULSEPAL: self.stop_opto_stim}) + soft_code_dict.update({SOFTCODE_FIRE_PULSEPAL: self.start_opto_stim}) + self.bpod.register_softcodes(soft_code_dict) + + @abstractmethod + def arm_opto_stim(self, ttl_output_channel): + raise NotImplementedError("User must define the stimulus and trigger type to deliver with pulsepal") + # Define the pulse sequence and load it to the desired output channel here + # This method should not fire the pulse train, that is handled by start_opto_stim() (soft-trigger) or a hardware trigger + # See https://github.com/sanworks/PulsePal/blob/master/Python/Python3/PulsePalExample.py for examples + # you should also define the max_stim_seconds property here to set the maximum duration of the pulse train + + ############################## + # Example code to define a sine wave lasting 5 seconds + voltages = list(range(0, 1000)) + for i in voltages: + voltages[i] = math.sin(voltages[i]/float(10))*10 # Set 1,000 voltages to create a 20V peak-to-peak sine waveform + times = np.linspace(0, 5, len(voltages)) # Create a time vector for the waveform + self.stim_length_seconds = times[-1] # it is essential to get this property right so that the TTL for recording stim pulses is correcty defined + self.pulsepal_connection.sendCustomPulseTrain(1, times, voltages) + self.pulsepal_connection.programOutputChannelParam('customTrainID', 1, 1) + ############################## + + @property + @abstractmethod + def stim_length_seconds(): + # this should be set within the arm_opto_stim method + pass + + def arm_ttl_stim(self): + # a TTL pulse from channel 2 that rises when the opto stim starts and falls when it stops + log.warning('Arming TTL signal') + self.pulsepal_connection.programOutputChannelParam('phase1Duration', 2, self.stim_length_seconds) + self.pulsepal_connection.sendCustomPulseTrain(2, [0,], [V_MAX,]) + self.pulsepal_connection.programOutputChannelParam('customTrainID', 2, 2) + + def start_opto_stim(self): + self.pulsepal_connection.triggerOutputChannels(1, 1, 0, 0) + log.warning('Started opto stim') + + def stop_opto_stim(self): + # this will stop the pulse train instantly (and the corresponding TTL pulse) + # To avoid rebound spiking in the case of GtACR, a ramp down is recommended + self.pulsepal_connection.abortPulseTrains() + + def compute_vmax_from_calibration(self, calibration_value): + # TODO: implement this method to convert the calibration value to a voltage for the opto stim + pass + + def __del__(self): + del self.pulsepal_connection diff --git a/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/__init__.py b/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/task.py b/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/task.py new file mode 100644 index 0000000..e055383 --- /dev/null +++ b/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/task.py @@ -0,0 +1,212 @@ +""" +This task is a replica of max_staticTrainingChoiceWorld with the addition of optogenetic stimulation +An `opto_stimulation` column is added to the trials_table, which is a boolean array of length NTRIALS_INIT +The PROBABILITY_OPTO_STIMULATION parameter is used to determine the probability of optogenetic stimulation +for each trial + +Additionally the state machine is modified to add output TTLs for optogenetic stimulation +""" + +import logging +import random +import sys +from importlib.util import find_spec +from pathlib import Path +from typing import Literal +import pandas as pd + +import numpy as np +import yaml +import time + +import iblrig +from iblrig.base_choice_world import SOFTCODE +from pybpodapi.protocol import StateMachine +from iblrig_custom_tasks.max_staticTrainingChoiceWorld.task import Session as StaticTrainingChoiceSession +from iblrig_custom_tasks.max_optoStaticTrainingChoiceWorld.PulsePal import PulsePalMixin, PulsePalStateMachine + +stim_location_history = [] + +log = logging.getLogger('iblrig.task') + +NTRIALS_INIT = 2000 +SOFTCODE_FIRE_LED = max(SOFTCODE).value + 1 +SOFTCODE_RAMP_DOWN_LED = max(SOFTCODE).value + 2 +RAMP_SECONDS = .25 # time to ramp down the opto stim # TODO: make this a parameter +LED_V_MAX = 5 # maximum voltage for LED control # TODO: make this a parameter + +# read defaults from task_parameters.yaml +with open(Path(__file__).parent.joinpath('task_parameters.yaml')) as f: + DEFAULTS = yaml.safe_load(f) + +class Session(StaticTrainingChoiceSession, PulsePalMixin): + protocol_name = 'max_optoStaticTrainingChoiceWorld' + extractor_tasks = ['PulsePalTrials'] + + def __init__( + self, + *args, + probability_opto_stim: float = DEFAULTS['PROBABILITY_OPTO_STIM'], + opto_ttl_states: list[str] = DEFAULTS['OPTO_TTL_STATES'], + opto_stop_states: list[str] = DEFAULTS['OPTO_STOP_STATES'], + max_laser_time: float = DEFAULTS['MAX_LASER_TIME'], + estimated_led_power_mW: float = DEFAULTS['ESTIMATED_LED_POWER_MW'], + **kwargs, + ): + super().__init__(*args, **kwargs) + self.task_params['OPTO_TTL_STATES'] = opto_ttl_states + self.task_params['OPTO_STOP_STATES'] = opto_stop_states + self.task_params['PROBABILITY_OPTO_STIM'] = probability_opto_stim + self.task_params['MAX_LASER_TIME'] = max_laser_time + self.task_params['LED_POWER'] = estimated_led_power_mW + # generates the opto stimulation for each trial + opto = np.random.choice( + [0, 1], + p=[1 - probability_opto_stim, probability_opto_stim], + size=NTRIALS_INIT, + ).astype(bool) + + opto[0] = False + self.trials_table['opto_stimulation'] = opto + + # get the calibration values for the LED + # TODO: do a calibration curve instead + dat = pd.read_csv(r'Y:/opto_fiber_calibration_values.csv') + l_cannula = f'{kwargs["subject"]}L' #TODO: where is SUBJECT defined? + r_cannula = f'{kwargs["subject"]}R' + l_cable = 0 + r_cable = 1 + l_cal_power = dat[(dat['Cannula'] == l_cannula) & (dat['cable_ID'] == l_cable)].cable_power.values[0] + r_cal_power = dat[(dat['Cannula'] == r_cannula) & (dat['cable_ID'] == r_cable)].cable_power.values[0] + + mean_cal_power = np.mean([l_cal_power, r_cal_power]) + vmax = LED_V_MAX * self.task_params['LED_POWER'] / mean_cal_power + log.warning(f'Using VMAX: {vmax}V for target LED power {self.task_params["LED_POWER"]}mW') + self.task_params['VMAX_LED'] = vmax + + def _instantiate_state_machine(self, trial_number=None): + """ + We override this using the custom class PulsePalStateMachine that appends TTLs for optogenetic stimulation where needed + :param trial_number: + :return: + """ + # PWM1 is the LED OUTPUT for port interface board + # Input is PortIn1 + # TODO: enable input port? + log.warning('Instantiating state machine') + is_opto_stimulation = self.trials_table.at[trial_number, 'opto_stimulation'] + if is_opto_stimulation: + self.arm_opto_stim() + self.arm_ttl_stim() + return PulsePalStateMachine( + self.bpod, + trigger_type='soft', # software trigger + is_opto_stimulation=is_opto_stimulation, + states_opto_ttls=self.task_params['OPTO_TTL_STATES'], + states_opto_stop=self.task_params['OPTO_STOP_STATES'], + opto_t_max_seconds=self.task_params['MAX_LASER_TIME'], + ) + + def arm_opto_stim(self): + # define a contant offset voltage with a ramp down at the end to avoid rebound excitation + log.warning('Arming opto stim') + ramp = np.linspace(self.task_params['VMAX_LED'], 0, 1000) # SET POWER + t = np.linspace(0, RAMP_SECONDS, 1000) + v = np.concatenate((np.array([self.task_params['VMAX_LED']]), ramp)) # SET POWER + t = np.concatenate((np.array([0]), t + self.task_params['MAX_LASER_TIME'])) + + self.pulsepal_connection.programOutputChannelParam('phase1Duration', 1, self.task_params['MAX_LASER_TIME']) + self.pulsepal_connection.sendCustomPulseTrain(1, t, v) + self.pulsepal_connection.programOutputChannelParam('customTrainID', 1, 1) + + def start_opto_stim(self): + super().start_opto_stim() + self.opto_start_time = time.time() + + @property + def stim_length_seconds(self): + return self.task_params['MAX_LASER_TIME'] + + def stop_opto_stim(self): + if time.time() - self.opto_start_time >= self.task_params['MAX_LASER_TIME']: + # the LED should have turned off by now, we don't need to force the ramp down + log.warning('Stopped opto stim - hit opto timeout') + return + + # we will modify this function to ramp down the opto stim rather than abruptly stopping it + # send instructions to set the TTL back to 0 + self.pulsepal_connection.programOutputChannelParam('phase1Duration', 2, self.task_params['MAX_LASER_TIME']) + self.pulsepal_connection.sendCustomPulseTrain(2, [0,], [0,]) + self.pulsepal_connection.programOutputChannelParam('customTrainID', 2, 2) + + # send instructions to ramp the opto stim down to 0 + v = np.linspace(self.task_params['VMAX_LED'], 0, 1000) + t = np.linspace(0, RAMP_SECONDS, 1000) + self.pulsepal_connection.programOutputChannelParam('phase1Duration', 1, self.task_params['MAX_LASER_TIME']) + self.pulsepal_connection.sendCustomPulseTrain(1, t, v) + self.pulsepal_connection.programOutputChannelParam('customTrainID', 1, 1) + + # trigger these instructions + self.pulsepal_connection.triggerOutputChannels(1, 1, 0, 0) + log.warning('Stopped opto stim - hit a stop opto state') + + def start_hardware(self): + super().start_hardware() + super().start_opto_hardware() + + + @staticmethod + def extra_parser(): + """:return: argparse.parser()""" + parser = super(Session, Session).extra_parser() + parser.add_argument( + '--probability_opto_stim', + option_strings=['--probability_opto_stim'], + dest='probability_opto_stim', + default=DEFAULTS['PROBABILITY_OPTO_STIM'], + type=float, + help=f'probability of opto-genetic stimulation (default: {DEFAULTS["PROBABILITY_OPTO_STIM"]})', + ) + + parser.add_argument( + '--opto_ttl_states', + option_strings=['--opto_ttl_states'], + dest='opto_ttl_states', + default=DEFAULTS['OPTO_TTL_STATES'], + nargs='+', + type=str, + help='list of the state machine states where opto stim should be delivered', + ) + parser.add_argument( + '--opto_stop_states', + option_strings=['--opto_stop_states'], + dest='opto_stop_states', + default=DEFAULTS['OPTO_STOP_STATES'], + nargs='+', + type=str, + help='list of the state machine states where opto stim should be stopped', + ) + parser.add_argument( + '--max_laser_time', + option_strings=['--max_laser_time'], + dest='max_laser_time', + default=DEFAULTS['MAX_LASER_TIME'], + type=float, + help='Maximum laser duration in seconds', + ) + parser.add_argument( + '--estimated_led_power_mW', + option_strings=['--estimated_led_power_mW'], + dest='estimated_led_power_mW', + default=DEFAULTS['ESTIMATED_LED_POWER_MW'], + type=float, + help='The estimated LED power in mW. Computed from a calibration curve' + ) + + return parser + + +if __name__ == '__main__': # pragma: no cover + kwargs = iblrig.misc.get_task_arguments(parents=[Session.extra_parser()]) + sess = Session(**kwargs) + sess.run() diff --git a/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/task_parameters.yaml b/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/task_parameters.yaml new file mode 100644 index 0000000..a2e86ee --- /dev/null +++ b/iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/task_parameters.yaml @@ -0,0 +1,33 @@ +'CONTRAST_SET': [1.0, 0.25, 0.125, 0.0625, 0.0, 0.0, 0.0625, 0.125, 0.25, 1.0] # signed contrast set +'PROBABILITY_SET': [2, 2, 2, 2, 1, 1, 2, 2, 2, 2] # scalar or list of n signed contrasts values, if scalar all contingencies are equiprobable +'REWARD_SET_UL': [1.5] # scalar or list of Ncontrast values +'POSITION_SET': [-35, -35, -35, -35, -35, 35, 35, 35, 35, 35] # position set +'STIM_GAIN': 4.0 # wheel to stimulus relationship +'STIM_REVERSE': False +#'DEBIAS': True # Whether to use debiasing rule or not by repeating error trials # todo + +# Opto parameters +'OPTO_TTL_STATES': # list of the state machine states where opto stim should be delivered + - trial_start +'OPTO_STOP_STATES': + - no_go + - error + - reward +'PROBABILITY_OPTO_STIM': 0.2 # probability of optogenetic stimulation +'MAX_LASER_TIME': 6.0 +'ESTIMATED_LED_POWER_MW': 2.5 +#'MASK_TTL_STATES': # list of the state machine states where mask stim should be delivered +# - trial_start +# - delay_initiation +# - reset_rotary_encoder +# - quiescent_period +# - stim_on +# - interactive_delay +# - play_tone +# - reset2_rotary_encoder +# - closed_loop +# - no_go +# - freeze_error +# - error +# - freeze_reward +# - reward diff --git a/iblrig_custom_tasks/max_staticTrainingChoiceWorld/__init__.py b/iblrig_custom_tasks/max_staticTrainingChoiceWorld/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/iblrig_custom_tasks/max_staticTrainingChoiceWorld/task.py b/iblrig_custom_tasks/max_staticTrainingChoiceWorld/task.py new file mode 100644 index 0000000..0bb9382 --- /dev/null +++ b/iblrig_custom_tasks/max_staticTrainingChoiceWorld/task.py @@ -0,0 +1,163 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import yaml +#import logging + +import iblrig.misc +from iblrig.base_choice_world import NTRIALS_INIT, ActiveChoiceWorldSession + +#log = logging.getLogger('iblrig.task') + +# read defaults from task_parameters.yaml +with open(Path(__file__).parent.joinpath('task_parameters.yaml')) as f: + DEFAULTS = yaml.safe_load(f) + + +class Session(ActiveChoiceWorldSession): + """ + This is a static version of trainingChoiceWorld, where debiasing and adaptive contrasts are disabled. + It does not have any blocks like the biased task. + Zero contrast probability is halved by default. + Highly similar to advancedChoiceWorld, but with slightly different contrast sets and probabilities. + + TODO: + - Maybe add a longer timeout for incorrect sessions? + - Add antibias back in if the mice struggle too much to learn + """ + + protocol_name = 'max_staticTrainingChoiceWorld' + extractor_tasks = ['ChoiceWorldTrials'] + + def __init__( + self, + *args, + contrast_set: list[float] = DEFAULTS['CONTRAST_SET'], + probability_set: list[float] = DEFAULTS['PROBABILITY_SET'], + reward_set_ul: list[float] = DEFAULTS['REWARD_SET_UL'], + position_set: list[float] = DEFAULTS['POSITION_SET'], + stim_gain: float = DEFAULTS['STIM_GAIN'], + stim_reverse: float = DEFAULTS['STIM_REVERSE'], + feedback_error_delay_secs: float = DEFAULTS['FEEDBACK_ERROR_DELAY_SECS'], + **kwargs, + ): + super().__init__(*args, **kwargs) + nc = len(contrast_set) + assert len(probability_set) in [nc, 1], 'probability_set must be a scalar or have the same length as contrast_set' + assert len(reward_set_ul) in [nc, 1], 'reward_set_ul must be a scalar or have the same length as contrast_set' + assert len(position_set) == nc, 'position_set must have the same length as contrast_set' + self.task_params['CONTRAST_SET'] = contrast_set + self.task_params['PROBABILITY_SET'] = probability_set + self.task_params['REWARD_SET_UL'] = reward_set_ul + self.task_params['POSITION_SET'] = position_set + self.task_params['STIM_GAIN'] = stim_gain + self.task_params['STIM_REVERSE'] = stim_reverse + self.task_params['FEEDBACK_ERROR_DELAY_SECS'] = feedback_error_delay_secs # make the punishment timeout a parameter + # it is easier to work with parameters as a dataframe + self.df_contingencies = pd.DataFrame(columns=['contrast', 'probability', 'reward_amount_ul', 'position']) + self.df_contingencies['contrast'] = contrast_set + self.df_contingencies['probability'] = np.float64(probability_set if len(probability_set) == nc else probability_set[0]) + self.df_contingencies['reward_amount_ul'] = reward_set_ul if len(reward_set_ul) == nc else reward_set_ul[0] + self.df_contingencies['position'] = position_set + # normalize the probabilities + self.df_contingencies.loc[:, 'probability'] = self.df_contingencies.loc[:, 'probability'] / np.sum( + self.df_contingencies.loc[:, 'probability'] + ) + # update the PROBABILITY LEFT field to reflect the probabilities in the parameters above + self.task_params['PROBABILITY_LEFT'] = np.sum( + self.df_contingencies['probability'] * (self.df_contingencies['position'] < 0) + ) + self.trials_table['debias_trial'] = np.zeros(NTRIALS_INIT, dtype=bool) + + def draw_next_trial_info(self, **kwargs): + nc = self.df_contingencies.shape[0] + ic = np.random.choice(np.arange(nc), p=self.df_contingencies['probability']) + # now calling the super class with the proper parameters + super().draw_next_trial_info( + pleft=self.task_params.PROBABILITY_LEFT, + contrast=self.df_contingencies.at[ic, 'contrast'], + position=self.df_contingencies.at[ic, 'position'], + reward_amount=self.df_contingencies.at[ic, 'reward_amount_ul'], + ) + + @property + def reward_amount(self): + return self.task_params.REWARD_AMOUNTS_UL[0] + + @staticmethod + def extra_parser(): + """:return: argparse.parser()""" + parser = super(Session, Session).extra_parser() + parser.add_argument( + '--contrast_set', + option_strings=['--contrast_set'], + dest='contrast_set', + default=DEFAULTS['CONTRAST_SET'], + nargs='+', + type=float, + help='Set of contrasts to present', + ) + parser.add_argument( + '--probability_set', + option_strings=['--probability_set'], + dest='probability_set', + default=DEFAULTS['PROBABILITY_SET'], + nargs='+', + type=float, + help='Probabilities of each contrast in contrast_set. If scalar all contrasts are equiprobable', + ) + parser.add_argument( + '--reward_set_ul', + option_strings=['--reward_set_ul'], + dest='reward_set_ul', + default=DEFAULTS['REWARD_SET_UL'], + nargs='+', + type=float, + help='Reward for contrast in contrast set.', + ) + parser.add_argument( + '--feedback_error_delay_secs', + option_strings=['--feedback_error_delay_secs'], + dest='feedback_error_delay_secs', + default=DEFAULTS['FEEDBACK_ERROR_DELAY_SECS'], + type=float, + help='The punishment timeout duration (s) for incorrect choice trials', + ) + parser.add_argument( + '--position_set', + option_strings=['--position_set'], + dest='position_set', + default=DEFAULTS['POSITION_SET'], + nargs='+', + type=float, + help='Position for each contrast in contrast set.', + ) + parser.add_argument( + '--stim_gain', + option_strings=['--stim_gain'], + dest='stim_gain', + default=DEFAULTS['STIM_GAIN'], + type=float, + help=f'Visual angle/wheel displacement ' f'(deg/mm, default: {DEFAULTS["STIM_GAIN"]})', + ) + parser.add_argument( + '--stim_reverse', + option_strings=['--stim_reverse'], + action='store_true', + dest='stim_reverse', + help='Inverse relationship of wheel to stimulus movement', + ) + return parser + + def next_trial(self): + # update counters + self.trial_num += 1 + # save and send trial info to bonsai + self.draw_next_trial_info(pleft=self.task_params.PROBABILITY_LEFT) + + +if __name__ == '__main__': # pragma: no cover + kwargs = iblrig.misc.get_task_arguments(parents=[Session.extra_parser()]) + sess = Session(**kwargs) + sess.run() \ No newline at end of file diff --git a/iblrig_custom_tasks/max_staticTrainingChoiceWorld/task_parameters.yaml b/iblrig_custom_tasks/max_staticTrainingChoiceWorld/task_parameters.yaml new file mode 100644 index 0000000..bce0c0b --- /dev/null +++ b/iblrig_custom_tasks/max_staticTrainingChoiceWorld/task_parameters.yaml @@ -0,0 +1,8 @@ +'CONTRAST_SET': [1.0, 0.5, 0.25, 0.125, 0.0625, 0.0, 0.0, 0.0625, 0.125, 0.25, 0.5, 1.0] # signed contrast set +'PROBABILITY_SET': [2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2] # scalar or list of n signed contrasts values, if scalar all contingencies are equiprobable +'REWARD_SET_UL': [1.5] # scalar or list of Ncontrast values +'POSITION_SET': [-35, -35, -35, -35, -35, -35, 35, 35, 35, 35, 35, 35] # position set +'STIM_GAIN': 4.0 # wheel to stimulus relationship +'STIM_REVERSE': False +'FEEDBACK_ERROR_DELAY_SECS': 2.0 +#'DEBIAS': True # Whether to use debiasing rule or not by repeating error trials # todo diff --git a/iblrig_custom_tasks/max_staticTrainingChoiceWorld/test_max_staticTrainingChoiceWorld.py b/iblrig_custom_tasks/max_staticTrainingChoiceWorld/test_max_staticTrainingChoiceWorld.py new file mode 100644 index 0000000..e69de29 diff --git a/projects/max_optoStaticTrainingChoiceWorld.py b/projects/max_optoStaticTrainingChoiceWorld.py new file mode 100644 index 0000000..cae0977 --- /dev/null +++ b/projects/max_optoStaticTrainingChoiceWorld.py @@ -0,0 +1,103 @@ +"""Bpod extractor for max_optoStaticTrainingChoiceWorld task. + +This is the same as advancedChoiceWorld with the addition of one dataset, `optoStimulation.intervals`; The times the +led was on. +""" + +import numpy as np +import ibllib.io.raw_data_loaders as raw +from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes +from ibllib.io.extractors.bpod_trials import TrainingTrials, BiasedTrials # was BiasedTrials +from ibllib.pipes.behavior_tasks import ChoiceWorldTrialsNidq, ChoiceWorldTrialsBpod +from ibllib.qc.task_metrics import TaskQC as BaseTaskQC +from inspect import getmembers, ismethod + +class TaskQC(BaseTaskQC): + def _get_checks(self): + def is_metric(x): + return ismethod(x) and x.__name__.startswith('check_') + + checks = super()._get_checks() + checks.update(dict(getmembers(self, is_metric))) + return checks + + def check_opto_percentage(self, data, **_): + p_opto = self.extractor.settings['PROBABILITY_OPTO_STIM'] + is_opto_trial = ~np.isnan(data['opto_intervals'][:, 0]) + n_trials = len(is_opto_trial) + actual_p_opto = np.sum(is_opto_trial) / n_trials + passed = np.isclose(p_opto, actual_p_opto, rtol=0, atol=.2) + return actual_p_opto, passed + + def check_opto_stim_intervals(self, data, **_): + """ + 1. Verify that the laser stimulation intervals are within the trial intervals of an opto_on trial. + 2. Verify that the laser stimulation intervals are greater than 0 and less than t_max. + + + Parameters + ---------- + data : dict + Map of trial data with keys ('opto_intervals', 'opto_stimulation'). + + Returns + ------- + numpy.array + An array the length of trials of metric M. + numpy.array + An boolean array the length of trials where True indicates the metric passed the + criterion. + """ + t_max = self.extractor.settings['MAX_LASER_TIME'] + is_opto_trial = ~np.isnan(data['opto_intervals'][:, 0]) + + opto_on_length = data['opto_intervals'][:,1] - data['opto_intervals'][:,0] + tol = .01 # seconds + passed = (opto_on_length < t_max + tol) | ~is_opto_trial # less than t_max + passed = passed & ((opto_on_length > 0) | ~is_opto_trial) # greater than zero + return opto_on_length, passed + +class TrialsOpto(BaseBpodTrialsExtractor): + var_names = BiasedTrials.var_names + ('opto_intervals',) + save_names = BiasedTrials.save_names + ('_ibl_optoStimulation.intervals.npy',) + + def _extract(self, extractor_classes=None, **kwargs) -> dict: + settings = self.settings.copy() + assert {'OPTO_STOP_STATES', 'OPTO_TTL_STATES', 'PROBABILITY_OPTO_STIM'} <= set(settings) + # Get all detected TTLs. These are stored for QC purposes + self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials) + # Extract common biased choice world datasets + out, _ = run_extractor_classes( + [BiasedTrials], session_path=self.session_path, bpod_trials=self.bpod_trials, + settings=settings, save=False, task_collection=self.task_collection) + + # Extract opto dataset + laser_intervals = [] + #for trial in filter(lambda t: t['opto_stimulation'], self.bpod_trials): + for trial in self.bpod_trials: + # the PulsePal TTL is wired into Bpod port 2. Hi for led on, lo for led off + events = trial['behavior_data']['Events timestamps'] + if 'Port2In' in events and 'Port2Out' in events: + start = events['Port2In'][0] + stop = events['Port2Out'][0] # TODO: make this handle multiple opto events per trial + else: + start = np.nan + stop = np.nan + laser_intervals.append((start, stop)) + out['opto_intervals'] = np.array(laser_intervals, dtype=np.float64) + + return {k: out[k] for k in self.var_names} # Ensures all datasets present and ordered + +class PulsePalTrialsBpod(ChoiceWorldTrialsBpod): + """Extract bpod only trials and pulsepal stimulation data.""" + @property + def signature(self): + signature = super().signature + signature['output_files'].append(('*optoStimulation.intervals.npy', self.output_collection, True)) + return signature + + def run_qc(self, trials_data=None, update=True, QC=TaskQC,**kwargs): + return super().run_qc(trials_data=trials_data, update=update, QC=QC, **kwargs) + + +# TODO: will eventually need to write the nidaq extractor \ No newline at end of file diff --git a/projects/task_extractor_map.json b/projects/task_extractor_map.json index 14444bb..0603614 100644 --- a/projects/task_extractor_map.json +++ b/projects/task_extractor_map.json @@ -2,6 +2,8 @@ "samuel_cuedBiasedChoiceWorld": "BiasedTrials", "_iblrig_tasks_neuromodulatorChoiceWorld": "projects.neuromodulators.TrialsTableNeuromodulator", "nate_optoBiasedChoiceWorld": "projects.nate_optoBiasedChoiceWorld.TrialsOpto", + "max_optoStaticTrainingChoiceWorld": "projects.max_optoStaticTrainingChoiceWorld.TrialsOpto", + "max_staticTrainingChoiceWorld": "BiasedTrials" "FPChoiceWorld": "BiasedTrials", "FPROptoChoiceWorld": "projects.alejandro_FPLROptoChoiceWorld.TrialsFPLROpto", "FPLOptoChoiceWorld": "projects.alejandro_FPLROptoChoiceWorld.TrialsFPLROpto", diff --git a/pyproject.toml b/pyproject.toml index d55b4a5..d2842bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "project_extraction" -version = "0.6.5" +version = "0.7.0" description = "Custom extractors for satellite tasks" dynamic = [ "readme" ] keywords = [ "IBL", "neuro-science" ]