Skip to content

Commit

Permalink
feat: Add trial_columns and trialize transform() method (#473)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkrako authored Aug 18, 2023
1 parent 8b02c28 commit 7c38f5c
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 21 deletions.
39 changes: 36 additions & 3 deletions docs/source/tutorials/local-dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,48 @@
"}"
]
},
{
"cell_type": "markdown",
"id": "9e63f355",
"metadata": {},
"source": [
"## Column Definitions"
]
},
{
"cell_type": "markdown",
"id": "18b5b563",
"metadata": {},
"source": [
"The `trial_columns` argument can be used to specify which columns define a single trial.\n",
"\n",
"This is important for correctly applying all preprocessing methods.\n",
"\n",
"For this very small single-user dataset, a trial is just defined by `text_id` and `page_id`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f5e4789",
"metadata": {},
"outputs": [],
"source": [
"trial_columns = ['text_id', 'page_id']"
]
},
{
"cell_type": "markdown",
"id": "e039de9b",
"metadata": {},
"source": [
"The `time_column` and `pixel_columns` arguments can be used to specify your columns according to their content.\n",
"This is needed for correctly applying all preprocessing methods.\n",
"The `time_column` and `pixel_columns` arguments can be used to correctly map the columns in your dataframes.\n",
"\n",
"Depending on the content of your dataset, you can alternatively also provide `position_columns`, `velocity_columns` and `acceleration_columns`.\n",
"\n",
"Specifying these columns is needed for correctly applying preprocessing methods. For example, if you want to apply the `pix2deg` method, you will need to specify `pixel_columns` accordingly.\n",
"\n",
"Depending on the content of your dataset, you can alternatively also provide `position_columns`, `velocity_columns` and `acceleration_columns`."
"If your dataset has gaze positions available only in degrees of visual angle, you have to specify the `position_columns` instead."
]
},
{
Expand Down
1 change: 1 addition & 0 deletions src/pymovements/dataset/dataset_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class DatasetDefinition:

column_map: dict[str, str] = field(default_factory=dict)

trial_columns: list[str] | None = None
time_column: str | None = None
pixel_columns: list[str] | None = None
position_columns: list[str] | None = None
Expand Down
3 changes: 3 additions & 0 deletions src/pymovements/dataset/dataset_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def load_gaze_files(
gaze_df = GazeDataFrame(
gaze_data,
experiment=definition.experiment,
trial_columns=definition.trial_columns,
)

elif preprocessed and extension == 'csv':
Expand Down Expand Up @@ -261,6 +262,7 @@ def load_gaze_files(
gaze_df = GazeDataFrame(
gaze_data,
experiment=definition.experiment,
trial_columns=definition.trial_columns,
time_column=time_column,
pixel_columns=pixel_columns,
position_columns=position_columns,
Expand All @@ -272,6 +274,7 @@ def load_gaze_files(
gaze_df = GazeDataFrame(
gaze_data,
experiment=definition.experiment,
trial_columns=definition.trial_columns,
time_column=definition.time_column,
pixel_columns=definition.pixel_columns,
position_columns=definition.position_columns,
Expand Down
4 changes: 4 additions & 0 deletions src/pymovements/datasets/gazebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ class GazeBase(DatasetDefinition):
},
)

trial_columns: list[str] = field(
default_factory=lambda: ['round_id', 'subject_id', 'session_id', 'task_name'],
)

time_column: str = 'n'

position_columns: list[str] = field(default_factory=lambda: ['x', 'y'])
Expand Down
4 changes: 4 additions & 0 deletions src/pymovements/datasets/gazebasevr.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ class GazeBaseVR(DatasetDefinition):
},
)

trial_columns: list[str] = field(
default_factory=lambda: ['round_id', 'subject_id', 'session_id', 'task_name'],
)

time_column: str = 'n'

position_columns: list[str] = field(default_factory=lambda: ['lx', 'ly', 'rx', 'ry', 'x', 'y'])
Expand Down
4 changes: 4 additions & 0 deletions src/pymovements/datasets/judo1000.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ class JuDo1000(DatasetDefinition):
},
)

trial_columns: list[str] = field(
default_factory=lambda: ['subject_id', 'session_id', 'trial_id'],
)

time_column: str = 'time'
pixel_columns: list[str] = field(
default_factory=lambda: [
Expand Down
2 changes: 2 additions & 0 deletions src/pymovements/datasets/toy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ class ToyDataset(DatasetDefinition):
},
)

trial_columns: list[str] = field(default_factory=lambda: ['text_id', 'page_id'])

time_column: str = 'timestamp'

pixel_columns: list[str] = field(default_factory=lambda: ['x', 'y'])
Expand Down
2 changes: 2 additions & 0 deletions src/pymovements/datasets/toy_dataset_eyelink.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class ToyDatasetEyeLink(DatasetDefinition):

column_map: dict[str, str] = field(default_factory=lambda: {})

trial_columns: list[str] = field(default_factory=lambda: ['subject_id', 'session_id'])

time_column: str = 'time'

pixel_columns: list[str] = field(default_factory=lambda: ['x_pix', 'y_pix'])
Expand Down
13 changes: 12 additions & 1 deletion src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
data: pl.DataFrame | None = None,
experiment: Experiment | None = None,
*,
trial_columns: str | list[str] | None = None,
time_column: str | None = None,
pixel_columns: list[str] | None = None,
position_columns: list[str] | None = None,
Expand Down Expand Up @@ -166,6 +167,8 @@ def __init__(
data = data.clone()
self.frame = data

self.trial_columns = trial_columns

if time_column is not None:
self.frame = self.frame.rename({time_column: 'time'})

Expand Down Expand Up @@ -272,7 +275,15 @@ def transform(
_check_n_components(self.n_components)
kwargs['n_components'] = self.n_components

self.frame = self.frame.with_columns(transform_method(**kwargs))
if self.trial_columns is None:
self.frame = self.frame.with_columns(transform_method(**kwargs))
else:
self.frame = pl.concat(
[
df.with_columns(transform_method(**kwargs))
for group, df in self.frame.groupby(self.trial_columns, maintain_order=True)
],
)

def pix2deg(self) -> None:
"""Compute gaze positions in degrees of visual angle from pixel position coordinates.
Expand Down
74 changes: 57 additions & 17 deletions tests/gaze/gaze_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,7 @@ def fixture_experiment():
{
'time': np.arange(1000, 1010, 2),
'x_pix': np.arange(0, 1, 0.2),
'y_pix': [
20.0, 20.200000000000003, 20.400000000000006,
20.60000000000001, 20.80000000000001,
],
'y_pix': [20.0, 20.2, 20.4, 20.6, 20.8],
},
),
pixel_columns=['x_pix', 'y_pix'],
Expand Down Expand Up @@ -206,7 +203,7 @@ def fixture_experiment():
'time': [1000, 1000],
'x_pix': [49.5, 49.5],
'y_pix': [0.0, 0.0],
'x_dva': [26.335410003881348, 26.335410003881348],
'x_dva': [26.3354, 26.3354],
'y_dva': [0.0, 0.0],
},
),
Expand Down Expand Up @@ -243,7 +240,7 @@ def fixture_experiment():
'time': [1000, 1000],
'x_pix': [49.5, 49.5],
'y_pix': [0.0, 0.0],
'x_dva': [26.335410003881348, 26.335410003881348],
'x_dva': [26.3354, 26.3354],
'y_dva': [0.0, 0.0],
},
),
Expand Down Expand Up @@ -281,7 +278,7 @@ def fixture_experiment():
'x_pix': [49.5],
'y_pix': [0.0],
'x_dva': [0.0],
'y_dva': [-26.335410003881348],
'y_dva': [-26.3354],
},
),
pixel_columns=['x_pix', 'y_pix'],
Expand All @@ -294,9 +291,10 @@ def fixture_experiment():
{
'data': pl.from_dict(
{
'time': [1000, 1001, 1002],
'x_dva': [1.0, 1.0, 1.0],
'y_dva': [1.0, 1.1, 1.2],
'trial_id': [1, 1, 1, 2, 2, 2],
'time': [1000, 1001, 1002, 1003, 1004, 1005],
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [1.0, 1.1, 1.2, 1.0, 1.1, 1.2],
},
),
'experiment': pm.Experiment(
Expand All @@ -314,11 +312,12 @@ def fixture_experiment():
pm.GazeDataFrame(
data=pl.from_dict(
{
'time': [1000, 1001, 1002],
'x_dva': [1.0, 1.0, 1.0],
'y_dva': [1.0, 1.1, 1.2],
'x_vel': [None, 0.0, 0.0],
'y_vel': [None, 100.00000000000009, 99.99999999999987],
'trial_id': [1, 1, 1, 2, 2, 2],
'time': [1000, 1001, 1002, 1003, 1004, 1005],
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [1.0, 1.1, 1.2, 1.0, 1.1, 1.2],
'x_vel': [None, 0.0, 0.0, 0.0, 0.0, 0.0],
'y_vel': [None, 100.0, 100.0, -200.0, 100.0, 100.0],
},
),
position_columns=['x_dva', 'y_dva'],
Expand Down Expand Up @@ -355,7 +354,7 @@ def fixture_experiment():
'x_dva': [1.0, 1.0, 1.0],
'y_dva': [1.0, 1.1, 1.2],
'x_vel': [None, 0.0, None],
'y_vel': [None, 99.99999999999997, None],
'y_vel': [None, 100.0, None],
},
),
position_columns=['x_dva', 'y_dva'],
Expand Down Expand Up @@ -392,14 +391,55 @@ def fixture_experiment():
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [1.0, 1.1, 1.2, 1.3, 1.4],
'x_vel': [None, None, 0.0, None, None],
'y_vel': [None, None, 100.00000000000001, None, None],
'y_vel': [None, None, 100.0, None, None],
},
),
position_columns=['x_dva', 'y_dva'],
velocity_columns=['x_vel', 'y_vel'],
),
id='pos2vel_five_point',
),
pytest.param(
{
'data': pl.from_dict(
{
'trial_id': [1, 1, 1, 2, 2, 2],
'time': [1000, 1001, 1002, 1003, 1004, 1005],
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [1.0, 1.1, 1.2, 1.0, 1.1, 1.2],
},
),
'experiment': pm.Experiment(
sampling_rate=1000,
screen_width_px=100,
screen_height_px=100,
screen_width_cm=100,
screen_height_cm=100,
distance_cm=100,
origin='lower left',
),
'position_columns': ['x_dva', 'y_dva'],
'trial_columns': 'trial_id',
},
'pos2vel', {'method': 'preceding'},
pm.GazeDataFrame(
data=pl.from_dict(
{
'trial_id': [1, 1, 1, 2, 2, 2],
'time': [1000, 1001, 1002, 1003, 1004, 1005],
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [1.0, 1.1, 1.2, 1.0, 1.1, 1.2],
'x_vel': [None, 0.0, 0.0, None, 0.0, 0.0],
'y_vel': [None, 100.0, 100.0, None, 100.0, 100.0],
},
),
position_columns=['x_dva', 'y_dva'],
velocity_columns=['x_vel', 'y_vel'],
),
id='pos2vel_preceding_trialize_single_column_str',
),
],
)
def test_gaze_transform_expected_frame(
Expand Down

0 comments on commit 7c38f5c

Please sign in to comment.