Skip to content

Commit

Permalink
feat: Add GazeDataFrame.detect() (#511)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkrako authored Sep 15, 2023
1 parent 88af8d0 commit 9a10c04
Show file tree
Hide file tree
Showing 8 changed files with 1,082 additions and 213 deletions.
158 changes: 26 additions & 132 deletions src/pymovements/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"""This module provides the base dataset class."""
from __future__ import annotations

import inspect
from collections.abc import Callable
from copy import deepcopy
from pathlib import Path
Expand All @@ -34,7 +33,6 @@
from pymovements.dataset.dataset_definition import DatasetDefinition
from pymovements.dataset.dataset_library import DatasetLibrary
from pymovements.dataset.dataset_paths import DatasetPaths
from pymovements.events.detection import EventDetectionLibrary
from pymovements.events.frame import EventDataFrame
from pymovements.events.processing import EventGazeProcessor
from pymovements.gaze import GazeDataFrame
Expand Down Expand Up @@ -360,7 +358,7 @@ def detect_events(
self,
method: Callable[..., EventDataFrame] | str,
*,
eye: str | None = 'auto',
eye: str = 'auto',
clear: bool = False,
verbose: bool = True,
**kwargs: Any,
Expand Down Expand Up @@ -393,134 +391,19 @@ def detect_events(
Dataset
Returns self, useful for method cascading.
"""
self._check_gaze_dataframe()

if isinstance(method, str):
method = EventDetectionLibrary.get(method)

# this is just a work-around until merged columns are standard behavior
# https://github.com/aeye-lab/pymovements/pull/443
unnested_columns = {}
if 'position' in self.gaze[0].frame.columns:
unnested_columns_pos = [
'x_left_pos', 'y_left_pos',
'x_right_pos', 'y_right_pos',
'x_avg_pos', 'y_avg_pos',
][:self.gaze[0].n_components]
unnested_columns['position'] = unnested_columns_pos
else:
raise pl.exceptions.ColumnNotFoundError(
f'Column \'position\' not found.'
f' Available columns are: {self.gaze[0].frame.columns}',
)

if 'velocity' in self.gaze[0].frame.columns:
unnested_columns_vel = [
'x_left_vel', 'y_left_vel',
'x_right_vel', 'y_right_vel',
'x_avg_vel', 'y_avg_vel',
][:self.gaze[0].n_components]
unnested_columns['velocity'] = unnested_columns_vel
else:
raise pl.exceptions.ColumnNotFoundError(
f'Column \'velocity\' not found.'
f' Available columns are: {self.gaze[0].frame.columns}',
)

self.gaze[0].unnest('position', output_columns=unnested_columns['position'])
self.gaze[0].unnest('velocity', output_columns=unnested_columns['velocity'])

if (
isinstance(self.gaze[0].n_components, int)
and self.gaze[0].n_components < 4
and eye not in [None, 'auto']
):
raise AttributeError()

# Automatically infer eye to use for event detection.
if eye == 'auto':
if 'x_avg_pos' in self.gaze[0].columns:
eye = 'avg'
elif 'x_right_pos' in self.gaze[0].columns:
eye = 'right'
else:
eye = 'left'

position_columns = [f'x_{eye}_pos', f'y_{eye}_pos']
velocity_columns = [f'x_{eye}_vel', f'y_{eye}_vel']

# this is just a work-around until merged columns are standard behavior
# https://github.com/aeye-lab/pymovements/pull/443
self.gaze[0].nest(
input_columns=unnested_columns['position'],
output_column='position',
)
self.gaze[0].nest(
input_columns=unnested_columns['velocity'],
output_column='velocity',
return self.detect(
method=method,
eye=eye,
clear=clear,
verbose=verbose,
**kwargs,
)

disable_progressbar = not verbose

if not self.events or clear:
self.events = [EventDataFrame() for _ in self.fileinfo.iter_rows()]

for file_id, (gaze_df, fileinfo_row) in tqdm(
enumerate(zip(self.gaze, self.fileinfo.to_dicts())), disable=disable_progressbar,
):
# this is just a work-around until merged columns are standard behavior
# https://github.com/aeye-lab/pymovements/pull/443
gaze_df.unnest('position', output_columns=unnested_columns['position'])
gaze_df.unnest('velocity', output_columns=unnested_columns['velocity'])

positions = gaze_df.frame.select(position_columns).to_numpy()
velocities = gaze_df.frame.select(velocity_columns).to_numpy()
timesteps = gaze_df.frame.get_column('time').to_numpy()

method_args = inspect.getfullargspec(method).args

if 'positions' in method_args:
kwargs['positions'] = positions

if 'velocities' in method_args:
kwargs['velocities'] = velocities

if 'events' in method_args:
kwargs['events'] = self.events[file_id]

kwargs['timesteps'] = timesteps

new_event_df = method(**kwargs)

new_event_df.frame = dataset_files.add_fileinfo(
definition=self.definition,
df=new_event_df.frame,
fileinfo=fileinfo_row,
)

self.events[file_id].frame = pl.concat(
[self.events[file_id].frame, new_event_df.frame],
how='diagonal',
)

# this is just a work-around until merged columns are standard behavior
# https://github.com/aeye-lab/pymovements/pull/443
gaze_df.nest(
input_columns=unnested_columns['position'],
output_column='position',
)
gaze_df.nest(
input_columns=unnested_columns['velocity'],
output_column='velocity',
)

return self

def detect(
self,
method: Callable[..., EventDataFrame] | str,
*,
eye: str | None = 'auto',
eye: str = 'auto',
clear: bool = False,
verbose: bool = True,
**kwargs: Any,
Expand Down Expand Up @@ -555,13 +438,24 @@ def detect(
Dataset
Returns self, useful for method cascading.
"""
return self.detect_events(
method=method,
eye=eye,
clear=clear,
verbose=verbose,
**kwargs,
)
self._check_gaze_dataframe()

if not self.events:
self.events = [gaze.events for gaze in self.gaze]

disable_progressbar = not verbose
for file_id, (gaze, fileinfo_row) in tqdm(
enumerate(zip(self.gaze, self.fileinfo.to_dicts())), disable=disable_progressbar,
):
gaze.detect(method, eye=eye, clear=clear, **kwargs)
# workaround until events are fully part of the GazeDataFrame
gaze.events.frame = dataset_files.add_fileinfo(
definition=self.definition,
df=gaze.events.frame,
fileinfo=fileinfo_row,
)
self.events[file_id] = gaze.events
return self

def compute_event_properties(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/pymovements/dataset/dataset_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def add_fileinfo(
[
pl.lit(value).alias(column)
for column, value in fileinfo.items()
if column != 'filepath'
if column != 'filepath' and column not in df.columns
] + [pl.all()],
)

Expand Down
3 changes: 3 additions & 0 deletions src/pymovements/events/detection/_ivt.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def ivt(
if include_nan:
candidates = filter_candidates_remove_nans(candidates=candidates, values=velocities)

# Remove empty candidates.
candidates = [candidate for candidate in candidates if len(candidate) > 0]

# Filter all candidates by minimum duration.
candidates = [
candidate for candidate in candidates
Expand Down
Loading

0 comments on commit 9a10c04

Please sign in to comment.