Merge pull request #681 from int-brain-lab/passive_fix
Passive fix
mayofaulkner authored Dec 11, 2023
2 parents e0cac72 + 30d1d78 commit 6075d57
Showing 3 changed files with 49 additions and 68 deletions.
2 changes: 1 addition & 1 deletion brainbox/plot_base.py
@@ -568,7 +568,7 @@ def plot_probe(data, ax=None, show_cbar=True, make_pretty=True, fig_kwargs=dict(
im = NonUniformImage(ax, interpolation='nearest', cmap=data['cmap'])
im.set_clim(data['clim'][0], data['clim'][1])
im.set_data(x, y, dat.T)
ax.images.append(im)
ax.add_image(im)

ax.set_xlim(data['xlim'][0], data['xlim'][1])
ax.set_ylim(data['ylim'][0], data['ylim'][1])
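For context, newer Matplotlib releases no longer allow appending to `Axes.images` directly, which appears to be the motivation for this change; `Axes.add_image` is the supported way to register the artist. A minimal sketch of the pattern, with made-up `x`, `y` and `dat` arrays standing in for the plot data:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.image import NonUniformImage

# Made-up data: irregular x positions, regular y positions
x = np.array([0.0, 0.5, 1.5, 3.0])
y = np.linspace(0.0, 1.0, 20)
dat = np.random.rand(y.size, x.size)

fig, ax = plt.subplots()
im = NonUniformImage(ax, interpolation='nearest', cmap='viridis')
im.set_clim(0, 1)
im.set_data(x, y, dat)
ax.add_image(im)          # replaces ax.images.append(im)
ax.set_xlim(x.min(), x.max())
ax.set_ylim(y.min(), y.max())
plt.show()
```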
10 changes: 6 additions & 4 deletions ibllib/io/extractors/ephys_passive.py
@@ -227,8 +227,8 @@ def _get_passive_spacers(session_path, sync_collection='raw_ephys_data',
f'trace ({int(np.size(spacer_times) / 2)})'
)

if tmax is None: # TODO THIS NEEDS CHANGING AS FOR DYNAMIC PIPELINE F2TTL slower than valve
tmax = fttl['times'][-1]
if tmax is None:
tmax = sync['times'][-1]

spacer_times = np.r_[spacer_times.flatten(), tmax]
return spacer_times[0], spacer_times[1::2], spacer_times[2::2]
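To make the return value above concrete, here is a small sketch with made-up spacer on/off pairs and `tmax`: flattening the Nx2 array and appending `tmax` gives an interleaved vector whose odd entries are period starts (spacer offsets) and whose even entries are period ends (the next spacer onset, or `tmax` for the last period).

```python
import numpy as np

# Made-up spacer on/off pairs (seconds) and end of the sync trace
spacer_times = np.array([[10.0, 12.0],
                         [100.0, 102.0],
                         [200.0, 202.0]])
tmax = 300.0

flat = np.r_[spacer_times.flatten(), tmax]
# flat -> [ 10.  12. 100. 102. 200. 202. 300.]
t_start = flat[0]            # 10.0: onset of the first spacer
period_starts = flat[1::2]   # [ 12. 102. 202.]: each period begins when its spacer ends
period_ends = flat[2::2]     # [100. 200. 300.]: and ends at the next spacer onset, or tmax
```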
@@ -418,8 +418,10 @@ def _extract_passiveAudio_intervals(audio: dict, rig_version: str) -> Tuple[np.a
soundOff_times = audio["times"][audio["polarities"] < 0]

# Check they are the correct number
assert len(soundOn_times) == NTONES + NNOISES, "Wrong number of sound ONSETS"
assert len(soundOff_times) == NTONES + NNOISES, "Wrong number of sound OFFSETS"
assert len(soundOn_times) == NTONES + NNOISES, f"Wrong number of sound ONSETS, " \
f"{len(soundOn_times)}/{NTONES + NNOISES}"
assert len(soundOff_times) == NTONES + NNOISES, f"Wrong number of sound OFFSETS, " \
f"{len(soundOn_times)}/{NTONES + NNOISES}"

diff = soundOff_times - soundOn_times
# Tone is ~100ms so check if diff < 0.3
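The duration check mentioned in the comment above can be sketched as follows; the onset/offset values and the 0.3 s threshold separating ~100 ms go-cue tones from longer noise bursts are illustrative rather than taken from a real session.

```python
import numpy as np

# Illustrative onset/offset times (s): two short tones and one longer noise burst
soundOn_times = np.array([1.0, 2.0, 3.0])
soundOff_times = np.array([1.1, 2.5, 3.1])

diff = soundOff_times - soundOn_times
toneOn_times = soundOn_times[diff <= 0.3]    # ~100 ms events -> tones
noiseOn_times = soundOn_times[diff > 0.3]    # longer events -> noise bursts
```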
105 changes: 42 additions & 63 deletions ibllib/pipes/training_status.py
@@ -3,9 +3,9 @@

from ibllib.io.raw_data_loaders import load_bpod
from ibllib.oneibl.registration import _get_session_times
from ibllib.io.extractors.base import get_pipeline, get_session_extractor_type
from ibllib.io.extractors.base import get_session_extractor_type
from ibllib.io.session_params import read_params
import ibllib.pipes.dynamic_pipeline as dyn
from ibllib.io.extractors.bpod_trials import get_bpod_extractor

from iblutil.util import setup_logger
from ibllib.plots.snapshot import ReportSnapshot
@@ -22,6 +22,7 @@
import seaborn as sns
import boto3
from botocore.exceptions import ProfileNotFound, ClientError
from itertools import chain

logger = setup_logger(__name__)

@@ -87,43 +88,6 @@ def upload_training_table_to_aws(lab, subject):
return


def get_trials_task(session_path, one):
# If experiment description file then process this
experiment_description_file = read_params(session_path)
if experiment_description_file is not None:
tasks = []
pipeline = dyn.make_pipeline(session_path)
trials_tasks = [t for t in pipeline.tasks if 'Trials' in t]
for task in trials_tasks:
t = pipeline.tasks.get(task)
t.__init__(session_path, **t.kwargs)
tasks.append(t)
else:
# Otherwise default to old way of doing things
pipeline = get_pipeline(session_path)
if pipeline == 'training':
from ibllib.pipes.training_preprocessing import TrainingTrials
tasks = [TrainingTrials(session_path)]
elif pipeline == 'ephys':
from ibllib.pipes.ephys_preprocessing import EphysTrials
tasks = [EphysTrials(session_path)]
else:
try:
# try and look if there is a custom extractor in the personal projects extraction class
import projects.base
task_type = get_session_extractor_type(session_path)
PipelineClass = projects.base.get_pipeline(task_type)
pipeline = PipelineClass(session_path, one)
trials_task_name = next(task for task in pipeline.tasks if 'Trials' in task)
task = pipeline.tasks.get(trials_task_name)
task.__init__(session_path)
tasks = [task]
except Exception:
tasks = []

return tasks


def save_path(subj_path):
return Path(subj_path).joinpath('training.csv')

@@ -155,7 +119,7 @@ def load_existing_dataframe(subj_path):
def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
"""
Load trials data for session. First attempts to load from local session path, if this fails will attempt to download via ONE,
if this also fails, will then attempt to re-extraxt locally
if this also fails, will then attempt to re-extract locally
:param sess_path: session path
:param one: ONE instance
:param force: when True and if the session trials can't be found, will attempt to re-extract from the disk
@@ -207,19 +171,24 @@ def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
if 'probabilityLeft' not in trials.keys():
raise ALFObjectNotFound
except Exception:
# Finally try to rextract the trials data locally
# Finally try to re-extract the trials data locally
try:
# Get the tasks that need to be run
tasks = get_trials_task(sess_path, one)
if len(tasks) > 0:
for task in tasks:
status = task.run()
if status == 0:
return load_trials(sess_path, collections=collections, one=one, force=False)
else:
return
raw_collections, _ = get_data_collection(sess_path)

if len(raw_collections) == 0:
return None

trials_dict = {}
for i, collection in enumerate(raw_collections):
extractor = get_bpod_extractor(sess_path, task_collection=collection)
trials_data, _ = extractor.extract(task_collection=collection, save=False)
trials_dict[i] = alfio.AlfBunch.from_df(trials_data['table'])

if len(trials_dict) > 1:
trials = training.concatenate_trials(trials_dict)
else:
trials = None
trials = trials_dict[0]

except Exception as e:
if mode == 'raise':
raise Exception(f'Exhausted all possibilities for loading trials for {sess_path}') from e
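For orientation, a minimal usage sketch of this loader as described in its docstring (local load, then ONE download, then local re-extraction); the session path and the printed dataset are illustrative, and `collections` is presumably worked out from the session when left as `None`:

```python
from one.api import ONE
from ibllib.pipes.training_status import load_trials

one = ONE()
# Illustrative local session path
sess_path = '/mnt/data/Subjects/SW_001/2023-01-01/001'
trials = load_trials(sess_path, one, collections=None, force=True, mode='raise')
if trials is not None:
    print(trials['probabilityLeft'][:5])
```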
@@ -468,20 +437,29 @@ def get_data_collection(session_path):
:param session_path: path of session
:return:
"""
experiment_description_file = read_params(session_path)
if experiment_description_file is not None:
pipeline = dyn.make_pipeline(session_path)
trials_tasks = [t for t in pipeline.tasks if 'Trials' in t]
collections = [pipeline.tasks.get(task).kwargs['collection'] for task in trials_tasks]
if len(collections) == 1 and collections[0] == 'raw_behavior_data':
alf_collections = ['alf']
elif all(['raw_task_data' in c for c in collections]):
alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
else:
alf_collections = None
experiment_description = read_params(session_path)
collections = []
if experiment_description is not None:
task_protocols = experiment_description.get('tasks', [])
for i, (protocol, task_info) in enumerate(chain(*map(dict.items, task_protocols))):
if 'passiveChoiceWorld' in protocol:
continue
collection = task_info.get('collection', f'raw_task_data_{i:02}')
if collection == 'raw_passive_data':
continue
collections.append(collection)
else:
collections = ['raw_behavior_data']
settings = Path(session_path).rglob('_iblrig_taskSettings.raw.json')
for setting in settings:
if setting.parent.name != 'raw_passive_data':
collections.append(setting.parent.name)

if len(collections) == 1 and collections[0] == 'raw_behavior_data':
alf_collections = ['alf']
elif all(['raw_task_data' in c for c in collections]):
alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
else:
alf_collections = None

return collections, alf_collections
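The `chain(*map(dict.items, ...))` idiom above flattens the list of single-key protocol dictionaries found under the `tasks` key of an experiment description file. A small sketch with made-up entries (the protocol names and collection names are illustrative):

```python
from itertools import chain

# Made-up 'tasks' entries from an experiment description file
task_protocols = [
    {'_iblrig_tasks_trainingChoiceWorld': {'collection': 'raw_task_data_00'}},
    {'_iblrig_tasks_passiveChoiceWorld': {'collection': 'raw_passive_data'}},
]

for i, (protocol, task_info) in enumerate(chain(*map(dict.items, task_protocols))):
    if 'passiveChoiceWorld' in protocol or task_info.get('collection') == 'raw_passive_data':
        continue  # passive protocols are skipped, as in the function above
    print(i, task_info.get('collection', f'raw_task_data_{i:02}'))
# 0 raw_task_data_00   -> later mapped to the alf collection 'alf/task_00'
```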

Expand Down Expand Up @@ -561,6 +539,7 @@ def get_training_info_for_session(session_paths, one, force=True):

un_protocols = np.unique(protocols)
# Example, training, training, biased - training would be combined, biased not
sess_dict = None
if len(un_protocols) != 1:
print(f'Different protocols in same session {session_path} : {protocols}')
for prot in un_protocols:
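The protocol-combining behaviour hinted at in the comment above can be illustrated with a made-up list of per-collection protocols for one session:

```python
import numpy as np

protocols = ['training', 'training', 'biased']   # made-up protocols for one session
un_protocols = np.unique(protocols)              # array(['biased', 'training'], dtype='<U8')
# len(un_protocols) != 1, so the two 'training' collections would be combined into one
# summary while the 'biased' collection is summarised on its own
```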
