Skip to content

Commit

Permalink
feat: Add GazeDataFrame.apply() (#558)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkrako authored and josephine-funken committed Sep 25, 2023
1 parent 0af0bf5 commit fbd6d37
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/pymovements/events/detection/_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ def get(cls, name: str) -> Callable[..., EventDataFrame]:
"""
return cls.methods[name]

@classmethod
def __contains__(cls, name: str) -> bool:
"""Check if class contains method of given name.
Parameters
----------
name: str
Name of the method to check.
Returns
-------
bool
True if EventDetectionLibrary contains method with given name, else False.
"""
return name in cls.methods


def register_event_detection(
method: Callable[..., EventDataFrame],
Expand Down
21 changes: 21 additions & 0 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,27 @@ def __init__(
else:
self.events = events.copy()

def apply(
self,
function: str,
**kwargs: Any,
) -> None:
"""Apply preprocessing method to GazeDataFrame.
Parameters
----------
function: str
Name of the preprocessing method to apply.
kwargs:
kwargs that will be forwarded when calling the preprocessing method.
"""
if transforms.TransformLibrary.__contains__(function):
self.transform(function, **kwargs)
elif pm.events.EventDetectionLibrary.__contains__(function):
self.detect(function, **kwargs)
else:
raise ValueError(f"unsupported method '{function}'")

def transform(
self,
transform_method: str | Callable[..., pl.Expr],
Expand Down
16 changes: 16 additions & 0 deletions src/pymovements/gaze/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ def get(cls, name: str) -> Callable[..., pl.Expr]:
"""
return cls.methods[name]

@classmethod
def __contains__(cls, name: str) -> bool:
"""Check if class contains method of given name.
Parameters
----------
name: str
Name of the method to check.
Returns
-------
bool
True if TransformsLibrary contains method with given name, else False.
"""
return name in cls.methods


def register_transform(method: TransformMethod) -> TransformMethod:
"""Register a transform method."""
Expand Down
236 changes: 236 additions & 0 deletions tests/gaze/apply_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# Copyright (c) 2023 The pymovements Project Authors
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Test GazeDataFrame detect method."""
import numpy as np
import polars as pl
import pytest
from polars.testing import assert_frame_equal

import pymovements as pm
from pymovements.synthetic import step_function


@pytest.mark.parametrize(
('method', 'kwargs', 'gaze', 'expected'),
[
pytest.param(
'ivt',
{
'velocity_threshold': 1,
'minimum_duration': 2,
'eye': 'cyclops',
},
pm.gaze.from_numpy(
velocity=step_function(
length=100, steps=[0, 10], values=[(1, 1, 1, 1, 0, 0), (0, 0, 0, 0, 0, 0)],
),
orient='row',
experiment=pm.Experiment(1024, 768, 38, 30, 60, 'center', 10),
),
pm.gaze.from_numpy(
velocity=step_function(
length=100, steps=[0, 10], values=[(1, 1, 1, 1, 0, 0), (0, 0, 0, 0, 0, 0)],
),
orient='row',
experiment=pm.Experiment(1024, 768, 38, 30, 60, 'center', 10),
events=pm.events.EventDataFrame(
name='fixation',
onsets=[0],
offsets=[99],
),
),
id='ivt_constant_position_monocular_fixation_six_components_eye_cyclops',
),
pytest.param(
'microsaccades',
{
'threshold': 1e-5,
},
pm.gaze.from_numpy(
velocity=step_function(
length=100,
steps=[20, 30, 70, 80],
values=[(9, 9), (0, 0), (9, 9), (0, 0)],
start_value=(0, 0),
),
orient='row',
experiment=pm.Experiment(1024, 768, 38, 30, 60, 'center', 10),
),
pm.gaze.from_numpy(
velocity=step_function(
length=100,
steps=[20, 30, 70, 80],
values=[(9, 9), (0, 0), (9, 9), (0, 0)],
start_value=(0, 0),
),
orient='row',
experiment=pm.Experiment(1024, 768, 38, 30, 60, 'center', 10),
events=pm.EventDataFrame(
name='saccade',
onsets=[20, 70],
offsets=[29, 79],
),
),
id='microsaccades_four_steps_two_saccades',
),
pytest.param(
'fill',
{},
pm.gaze.from_numpy(
time=np.arange(0, 100),
events=pm.EventDataFrame(
name=['fixation', 'saccade'], onsets=[0, 50], offsets=[40, 100],
),
),
pm.gaze.from_numpy(
time=np.arange(0, 100),
events=pm.EventDataFrame(
name=['fixation', 'saccade', 'unclassified'],
onsets=[0, 50, 40],
offsets=[40, 100, 49],
),
),
id='fill_fixation_10_ms_break_then_saccade_until_end_single_fill',
),
pytest.param(
'downsample',
{'factor': 2},
pm.GazeDataFrame(
data=pl.from_dict(
{
'time': np.arange(1000, 1010, 1),
'x_pix': np.arange(0, 1, 0.1),
'y_pix': np.arange(20, 21, 0.1),
},
),
pixel_columns=['x_pix', 'y_pix'],
),
pm.GazeDataFrame(
data=pl.from_dict(
{
'time': np.arange(1000, 1010, 2),
'x_pix': np.arange(0, 1, 0.2),
'y_pix': [20.0, 20.2, 20.4, 20.6, 20.8],
},
),
pixel_columns=['x_pix', 'y_pix'],
),
id='downsample_factor_2',
),
pytest.param(
'pix2deg',
{},
pm.GazeDataFrame(
data=pl.from_dict(
{
'time': [1000, 1000],
'x_pix': [(100 - 1) / 2, (100 - 1) / 2],
'y_pix': [0.0, 0.0],
},
),
experiment=pm.Experiment(100, 100, 100, 100, 100, 'center', 1000),
pixel_columns=['x_pix', 'y_pix'],
),
pm.GazeDataFrame(
data=pl.from_dict(
{
'time': [1000, 1000],
'x_pix': [49.5, 49.5],
'y_pix': [0.0, 0.0],
'x_dva': [26.3354, 26.3354],
'y_dva': [0.0, 0.0],
},
),
pixel_columns=['x_pix', 'y_pix'],
position_columns=['x_dva', 'y_dva'],
),
id='pix2deg_origin_center',
),
pytest.param(
'pos2vel',
{'method': 'preceding'},
pm.GazeDataFrame(
data=pl.from_dict(
{
'trial_id': [1, 1, 1, 2, 2, 2],
'time': [1000, 1001, 1002, 1003, 1004, 1005],
'x': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y': [1.0, 1.1, 1.2, 1.0, 1.1, 1.2],
},
),
experiment=pm.Experiment(100, 100, 100, 100, 100, 'center', 1000),
trial_columns='trial_id',
position_columns=['x', 'y'],
),
pm.GazeDataFrame(
data=pl.from_dict(
{
'trial_id': [1, 1, 1, 2, 2, 2],
'time': [1000, 1001, 1002, 1003, 1004, 1005],
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [1.0, 1.1, 1.2, 1.0, 1.1, 1.2],
'x_vel': [None, 0.0, 0.0, None, 0.0, 0.0],
'y_vel': [None, 100.0, 100.0, None, 100.0, 100.0],
},
),
position_columns=['x_dva', 'y_dva'],
velocity_columns=['x_vel', 'y_vel'],
),
id='pos2vel_preceding_trialize_single_column_str',
),
],
)
def test_gaze_apply(method, kwargs, gaze, expected):
gaze.apply(method, **kwargs)
assert_frame_equal(gaze.frame, expected.frame)
assert_frame_equal(gaze.events.frame, expected.events.frame)


@pytest.mark.parametrize(
('method', 'kwargs', 'gaze', 'exception', 'exception_msg'),
[
pytest.param(
'foobar',
{},
pm.gaze.from_numpy(
velocity=step_function(
length=100, steps=[0, 10], values=[(1, 1, 1, 1, 0, 0), (0, 0, 0, 0, 0, 0)],
),
orient='row',
experiment=pm.Experiment(1024, 768, 38, 30, 60, 'center', 10),
),
ValueError,
"unsupported method 'foobar'",
id='unknown_method',
),
],
)
def test_gaze_apply_raises_exception(method, kwargs, gaze, exception, exception_msg):
with pytest.raises(exception) as exc_info:
gaze.apply(method, **kwargs)

msg, = exc_info.value.args
assert msg == exception_msg

0 comments on commit fbd6d37

Please sign in to comment.