Skip to content

Commit

Permalink
Add get_trials_tasks function
Browse files Browse the repository at this point in the history
  • Loading branch information
k1o0 committed Dec 15, 2023
1 parent 934f9da commit b4883f5
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 44 deletions.
156 changes: 124 additions & 32 deletions ibllib/io/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def extract(self, bpod_trials=None, settings=None, **kwargs):

def run_extractor_classes(classes, session_path=None, **kwargs):
"""
Run a set of extractors with the same inputs
Run a set of extractors with the same inputs.
:param classes: list of Extractor class
:param save: True/False
:param path_out: (defaults to alf path)
Expand Down Expand Up @@ -195,6 +196,23 @@ def run_extractor_classes(classes, session_path=None, **kwargs):


def _get_task_types_json_config():
"""
Return the extractor types map.
This function is only used for legacy sessions, i.e. those without an experiment description
file and will be removed in favor of :func:`_get_task_extractor_map`, which directly returns
the Bpod extractor class name. The experiment description file cuts out the need for pipeline
name identifiers.
Returns
-------
Dict[str, str]
A map of task protocol to task extractor identifier, e.g. 'ephys', 'habituation', etc.
See Also
--------
_get_task_extractor_map - returns a map of task protocol to Bpod trials extractor class name.
"""
with open(Path(__file__).parent.joinpath('extractor_types.json')) as fp:
task_types = json.load(fp)
try:
Expand All @@ -210,6 +228,26 @@ def _get_task_types_json_config():


def get_task_protocol(session_path, task_collection='raw_behavior_data'):
"""
Return the task protocol name from task settings.
If the session path and/or task collection do not exist, the settings file is missing or
otherwise can not be parsed, or if the 'PYBPOD_PROTOCOL' key is absent, None is returned.
A warning is logged if the session path or settings file doesn't exist. An error is logged if
the settings file can not be parsed.
Parameters
----------
session_path : str, pathlib.Path
The absolute session path.
task_collection : str
The session path directory containing the task settings file.
Returns
-------
str or None
The Pybpod task protocol name or None if not found.
"""
try:
settings = load_settings(get_session_path(session_path), task_collection=task_collection)
except json.decoder.JSONDecodeError:
Expand All @@ -223,11 +261,26 @@ def get_task_protocol(session_path, task_collection='raw_behavior_data'):

def get_task_extractor_type(task_name):
"""
Returns the task type string from the full pybpod task name:
_iblrig_tasks_biasedChoiceWorld3.7.0 returns "biased"
_iblrig_tasks_trainingChoiceWorld3.6.0 returns "training'
:param task_name:
:return: one of ['biased', 'habituation', 'training', 'ephys', 'mock_ephys', 'sync_ephys']
Returns the task type string from the full pybpod task name.
Parameters
----------
task_name : str
The complete task protocol name from the PYBPOD_PROTOCOL field of the task settings.
Returns
-------
str
The extractor type identifier. Examples include 'biased', 'habituation', 'training',
'ephys', 'mock_ephys' and 'sync_ephys'.
Examples
--------
>>> get_task_extractor_type('_iblrig_tasks_biasedChoiceWorld3.7.0')
'biased'
>>> get_task_extractor_type('_iblrig_tasks_trainingChoiceWorld3.6.0')
'training'
"""
if isinstance(task_name, Path):
task_name = get_task_protocol(task_name)
Expand All @@ -245,16 +298,30 @@ def get_task_extractor_type(task_name):

def get_session_extractor_type(session_path, task_collection='raw_behavior_data'):
"""
From a session path, loads the settings file, finds the task and checks if extractors exist
task names examples:
:param session_path:
:return: bool
Infer trials extractor type from task settings.
From a session path, loads the settings file, finds the task and checks if extractors exist.
Examples include 'biased', 'habituation', 'training', 'ephys', 'mock_ephys', and 'sync_ephys'.
Note this should only be used for legacy sessions, i.e. those without an experiment description
file.
Parameters
----------
session_path : str, pathlib.Path
The session path for which to determine the pipeline.
task_collection : str
The session path directory containing the raw task data.
Returns
-------
str or False
The task extractor type, e.g. 'biased', 'habituation', 'ephys', or False if unknown.
"""
settings = load_settings(session_path, task_collection=task_collection)
if settings is None:
_logger.error(f'ABORT: No data found in "{task_collection}" folder {session_path}')
task_protocol = get_task_protocol(session_path, task_collection=task_collection)
if task_protocol is None:
_logger.error(f'ABORT: No task protocol found in "{task_collection}" folder {session_path}')
return False
extractor_type = get_task_extractor_type(settings['PYBPOD_PROTOCOL'])
extractor_type = get_task_extractor_type(task_protocol)
if extractor_type:
return extractor_type
else:
Expand All @@ -263,28 +330,52 @@ def get_session_extractor_type(session_path, task_collection='raw_behavior_data'

def get_pipeline(session_path, task_collection='raw_behavior_data'):
    """
    Get the pre-processing pipeline name from a session path.

    Note this is only suitable for legacy sessions, i.e. those without an experiment description
    file. This function will be removed in the future.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session path for which to determine the pipeline.
    task_collection : str
        The session path directory containing the raw task data.

    Returns
    -------
    str
        The pipeline name inferred from the extractor type, e.g. 'ephys', 'training', 'widefield'.
    """
    # The extractor type is inferred from the task settings found in `task_collection`
    stype = get_session_extractor_type(session_path, task_collection=task_collection)
    return _get_pipeline_from_task_type(stype)


def _get_pipeline_from_task_type(stype):
"""
Returns the pipeline from the task type. Some tasks types directly define the pipeline
:param stype: session_type or task extractor type
:return:
Return the pipeline from the task type.
Some task types directly define the pipeline. Note this is only suitable for legacy sessions,
i.e. those without an experiment description file. This function will be removed in the future.
Parameters
----------
stype : str
The session type or task extractor type, e.g. 'habituation', 'ephys', etc.
Returns
-------
str
A task pipeline identifier.
"""
if stype in ['ephys_biased_opto', 'ephys', 'ephys_training', 'mock_ephys', 'sync_ephys']:
return 'ephys'
elif stype in ['habituation', 'training', 'biased', 'biased_opto']:
return 'training'
elif 'widefield' in stype:
elif isinstance(stype, str) and 'widefield' in stype:
return 'widefield'
else:
return stype
return stype or ''


def _get_task_extractor_map():
Expand All @@ -293,7 +384,7 @@ def _get_task_extractor_map():
Returns
-------
dict(str, str)
Dict[str, str]
A map of task protocol to Bpod trials extractor class.
"""
FILENAME = 'task_extractor_map.json'
Expand All @@ -315,34 +406,35 @@ def get_bpod_extractor_class(session_path, task_collection='raw_behavior_data'):
"""
Get the Bpod trials extractor class associated with a given Bpod session.
Note that unlike :func:`get_session_extractor_type`, this function maps directly to the Bpod
trials extractor class name. This is hardware invariant and is purely to determine the Bpod only
trials extractor.
Parameters
----------
session_path : str, pathlib.Path
The session path containing Bpod behaviour data.
task_collection : str
The session_path subfolder containing the Bpod settings file.
The session_path sub-folder containing the Bpod settings file.
Returns
-------
str
The extractor class name.
"""
# Attempt to load settings files
settings = load_settings(session_path, task_collection=task_collection)
if settings is None:
raise ValueError(f'No data found in "{task_collection}" folder {session_path}')
# Attempt to get task protocol
protocol = settings.get('PYBPOD_PROTOCOL')
# Attempt to get protocol name from settings file
protocol = get_task_protocol(session_path, task_collection=task_collection)
if not protocol:
raise ValueError(f'No task protocol found in {session_path/task_collection}')
raise ValueError(f'No task protocol found in {Path(session_path) / task_collection}')
return protocol2extractor(protocol)


def protocol2extractor(protocol):
"""
Get the Bpod trials extractor class associated with a given Bpod task protocol.
The Bpod task protocol can be found in the 'PYBPOD_PROTOCOL' field of _iblrig_taskSettings.raw.json.
The Bpod task protocol can be found in the 'PYBPOD_PROTOCOL' field of the
_iblrig_taskSettings.raw.json file.
Parameters
----------
Expand Down
67 changes: 63 additions & 4 deletions ibllib/pipes/dynamic_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import spikeglx

import ibllib.io.session_params as sess_params
import ibllib.io.extractors.base
from ibllib.io.extractors.base import get_pipeline, get_session_extractor_type
import ibllib.pipes.tasks as mtasks
import ibllib.pipes.base_tasks as bstasks
import ibllib.pipes.widefield_tasks as wtasks
Expand Down Expand Up @@ -45,7 +45,7 @@ def acquisition_description_legacy_session(session_path, save=False):
dict
The legacy acquisition description.
"""
extractor_type = ibllib.io.extractors.base.get_session_extractor_type(session_path=session_path)
extractor_type = get_session_extractor_type(session_path)
etype2protocol = dict(biased='choice_world_biased', habituation='choice_world_habituation',
training='choice_world_training', ephys='choice_world_recording')
dict_ad = get_acquisition_description(etype2protocol[extractor_type])
Expand Down Expand Up @@ -130,7 +130,7 @@ def make_pipeline(session_path, **pkwargs):
----------
session_path : str, Path
The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
**pkwargs
pkwargs
Optional arguments passed to the ibllib.pipes.tasks.Pipeline constructor.
Returns
Expand All @@ -147,7 +147,7 @@ def make_pipeline(session_path, **pkwargs):
if not acquisition_description:
raise ValueError('Experiment description file not found or is empty')
devices = acquisition_description.get('devices', {})
kwargs = {'session_path': session_path}
kwargs = {'session_path': session_path, 'one': pkwargs.get('one')}

# Registers the experiment description file
tasks['ExperimentDescriptionRegisterRaw'] = type('ExperimentDescriptionRegisterRaw',
Expand Down Expand Up @@ -430,3 +430,62 @@ def load_pipeline_dict(path):
task_list = yaml.full_load(file)

return task_list


def get_trials_tasks(session_path, one=None):
    """
    Return a list of pipeline trials extractor task objects for a given session.

    This function supports both legacy and dynamic pipeline sessions.

    Parameters
    ----------
    session_path : str, pathlib.Path
        An absolute path to a session.
    one : one.api.One
        A ONE instance.

    Returns
    -------
    list of pipes.tasks.Task
        A list of task objects for the provided session. The list is empty when no trials task
        could be determined for a legacy session.
    """
    # Check for an experiment.description file; ensure downloaded if possible
    if one and one.to_eid(session_path):  # to_eid returns None if session not registered
        one.load_datasets(
            session_path, ['_ibl_experiment.description'], download_only=True, assert_present=False)
    experiment_description = sess_params.read_params(session_path)

    # If experiment description file then use this to make the pipeline
    if experiment_description is not None:
        tasks = []
        pipeline = make_pipeline(session_path, one=one)
        for task_name in (name for name in pipeline.tasks if 'Trials' in name):
            task = pipeline.tasks.get(task_name)
            # Re-initialize the task object with the session path and its own stored kwargs
            task.__init__(session_path, **task.kwargs)
            tasks.append(task)
        return tasks

    # Otherwise default to old way of doing things
    pipeline = get_pipeline(session_path)
    if pipeline == 'training':
        from ibllib.pipes.training_preprocessing import TrainingTrials
        return [TrainingTrials(session_path, one=one)]
    if pipeline == 'ephys':
        from ibllib.pipes.ephys_preprocessing import EphysTrials
        return [EphysTrials(session_path, one=one)]

    tasks = []
    try:
        # try to find a custom extractor in the personal projects extraction class
        import projects.base
        task_type = get_session_extractor_type(session_path)
        # Explicit check instead of `assert (X := ...)`: asserts are stripped under `python -O`,
        # which would leave the pipeline class name unbound and raise an uncaught NameError.
        pipeline_class = projects.base.get_pipeline(task_type)
        if pipeline_class:
            pipeline = pipeline_class(session_path, one=one)
            trials_task_name = next(name for name in pipeline.tasks if 'Trials' in name)
            task = pipeline.tasks.get(trials_task_name)
            # NOTE(review): this calls the task object retrieved from the pipeline; presumably
            # intended to (re)bind the session path - confirm against the Task implementation.
            task(session_path)
            tasks = [task]
    except (ModuleNotFoundError, AssertionError, StopIteration):
        # No projects package, a failed assertion in project code, or no trials task in the
        # custom pipeline: there is nothing to return for this session.
        tasks = []

    return tasks
1 change: 1 addition & 0 deletions ibllib/pipes/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

def subjects_data_folder(folder: Path, rglob: bool = False) -> Path:
"""Given a root_data_folder will try to find a 'Subjects' data folder.
If Subjects folder is passed will return it directly."""
if not isinstance(folder, Path):
folder = Path(folder)
Expand Down
2 changes: 1 addition & 1 deletion ibllib/pipes/training_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ibllib.qc.task_extractors import TaskQCExtractor

_logger = logging.getLogger(__name__)
warnings.warn('`pipes.training_preprocessing` to be removed in favour of dynamic pipeline')
warnings.warn('`pipes.training_preprocessing` to be removed in favour of dynamic pipeline', FutureWarning)


# level 0
Expand Down
22 changes: 20 additions & 2 deletions ibllib/tests/fixtures/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def create_fake_raw_behavior_data_folder(
):
"""Create the folder structure for a raw behaviour session.
Creates a raw_behavior_data folder and optionally, touches some files and writes a experiment
Creates a raw_behavior_data folder and optionally, touches some files and writes an experiment
description stub to a `_devices` folder.
Parameters
Expand Down Expand Up @@ -304,8 +304,26 @@ def create_fake_raw_behavior_data_folder(


def populate_task_settings(fpath: Path, patch: dict) -> Path:
    """
    Populate a task settings JSON file.

    Parameters
    ----------
    fpath : pathlib.Path, str
        A path to a raw task settings folder or the full settings file path.
    patch : dict
        The settings dict to write to file.

    Returns
    -------
    pathlib.Path
        The full settings file path.
    """
    fpath = Path(fpath)  # also accept a str path
    if fpath.is_dir():
        # A folder was passed: write to the default settings file name within it
        fpath /= '_iblrig_taskSettings.raw.json'
    with fpath.open('w') as f:
        json.dump(patch, f, indent=1)
    return fpath


def create_fake_complete_ephys_session(
Expand Down
Loading

0 comments on commit b4883f5

Please sign in to comment.