Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: split_gaze_data into trial #859

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
1c6c769
feat: split_gaze_data into trial
SiQube Oct 23, 2024
976695b
docs: Add missing modules to documentation (#866)
dkrako Oct 23, 2024
953ade3
hotfix: check whether public dataset has gaze files (#872)
SiQube Oct 24, 2024
b842bdb
docs: correctly add EyeTracker class to gaze module (#876)
dkrako Oct 24, 2024
cc1bae1
feat: add support for .ias files in stimulus.text.from_file() (#858)
SiQube Oct 24, 2024
5417804
dataset: beijing sentence corpus (#857)
SiQube Oct 24, 2024
92b49a7
dataset: add InteRead dataset (#862)
SiQube Oct 24, 2024
f0b69a9
fix: copy event resource files instead of moving them to events direc…
SiQube Oct 24, 2024
e6a9ced
hotfix: CopCo dataset precomputed events loading (#873)
SiQube Oct 24, 2024
1b8c4bd
ci: ignore too-many-public-methods (#882)
dkrako Oct 25, 2024
69ef837
ci: pre-commit autoupdate (#889)
pre-commit-ci[bot] Oct 29, 2024
cfbce95
ci: pre-commit autoupdate (#890)
pre-commit-ci[bot] Nov 5, 2024
47e734d
build: add support for python 3.13 (#845)
SiQube Nov 7, 2024
166b076
build: update nbsphinx requirement from <0.9.5,>=0.8.8 to >=0.8.8,<0.…
dependabot[bot] Nov 7, 2024
495e5d9
ci: pre-commit autoupdate (#896)
pre-commit-ci[bot] Nov 12, 2024
b691e6d
build: update setuptools-git-versioning requirement from <2 to <3 (#895)
dependabot[bot] Nov 12, 2024
88113c8
hotfix: download link fakenewsperception dataset (#897)
SiQube Nov 13, 2024
21fd0d2
feat: Store metadata from ASC in experiment metadata (#884)
saeub Nov 14, 2024
0856658
move split method to gaze dataframe
SiQube Nov 17, 2024
4751e41
Merge branch 'main' into split-gaze-files-into-trial-dataframes
SiQube Nov 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/pymovements/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,32 @@ def load_precomputed_reading_measures(self) -> None:
self.paths,
)

def _split_gaze_data(
        self,
        by: list[str] | str,
) -> None:
    """Split gaze data into separate GazeDataFrames, one per group.

    Each loaded gaze frame is partitioned by the unique value combinations
    of the given column(s). The fileinfo row of a source file is duplicated
    once per split produced from that file, so ``self.fileinfo['gaze']``
    stays row-aligned with ``self.gaze``.

    Parameters
    ----------
    by: list[str] | str
        Column(s) to split each gaze dataframe by.
    """
    by = [by] if isinstance(by, str) else by

    fileinfo_dicts = self.fileinfo['gaze'].to_dicts()

    all_gaze_frames = []
    all_fileinfo_rows = []

    # Keep gaze frames and fileinfo rows in lockstep: every split that
    # originates from one file inherits that file's fileinfo row.
    for frame, fileinfo_row in zip(self.gaze, fileinfo_dicts):
        split_frames = frame.split(by=by)
        all_gaze_frames.extend(split_frames)
        all_fileinfo_rows.extend([fileinfo_row] * len(split_frames))

    self.gaze = all_gaze_frames
    # from_dicts() builds the fileinfo frame in a single pass instead of
    # concatenating one single-row frame per split.
    self.fileinfo['gaze'] = pl.from_dicts(all_fileinfo_rows)

def split_precomputed_events(
self,
by: list[str] | str,
Expand Down
41 changes: 41 additions & 0 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ def __init__(

# Remove this attribute once #893 is fixed
self._metadata: dict[str, Any] | None = None
self.auto_column_detect = auto_column_detect
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is only needed as a flag for autodetecting pixel, velocity, etc. columns. you don't need to store this.

self.time_column = time_column
self.time_unit = time_unit
self.pixel_columns = pixel_columns
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why should we be storing these? the pixel_columns are merged into a single column named pixel. after terminating __init__() the pixel_columns won't exist anymore in the dataframe.

self.position_columns = position_columns
self.velocity_columns = velocity_columns
self.acceleration_columns = acceleration_columns
self.distance_column = distance_column
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the distance column is called distance after initialization. if it was named differently before, it is renamed.


def apply(
self,
Expand All @@ -307,6 +315,39 @@ def apply(
else:
raise ValueError(f"unsupported method '{function}'")

def split(self, by: list[str] | str) -> list[GazeDataFrame]:
    """Split the GazeDataFrame into multiple frames based on specified column(s).

    Parameters
    ----------
    by: list[str] | str
        Column name(s) to split the DataFrame by. If a single string is provided,
        it will be used as a single column name. If a list is provided, the DataFrame
        will be split by unique combinations of values in all specified columns.

    Returns
    -------
    list[GazeDataFrame]
        A list of new GazeDataFrame instances, each containing a partition of the
        original data with the experiment and trial-column configuration preserved.
    """
    # By the time split() is called, __init__ has already merged the
    # original component columns (pixel_columns, position_columns, ...)
    # into nested 'pixel'/'position'/... columns and normalized the time
    # column, so those source columns no longer exist in self.frame.
    # The partitions therefore must NOT be re-initialized with the
    # original *_columns / time_unit / auto_column_detect arguments;
    # only experiment and trial_columns carry over.
    # Note: partition_by() accepts both a single column name and a list,
    # so no str -> list conversion is needed.
    return [
        GazeDataFrame(
            partition,
            experiment=self.experiment,
            trial_columns=self.trial_columns,
        )
        for partition in self.frame.partition_by(by=by)
    ]

def transform(
self,
transform_method: str | Callable[..., pl.Expr],
Expand Down
53 changes: 51 additions & 2 deletions tests/unit/dataset/dataset_test.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add tests to gaze_dataframe_test.py?

please check the resulting splits in a similar way like it is done in #879 , e.g. check equality of the by-column within a split and check for difference to all other splits.

Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def mock_toy(
'y_left_pix': np.zeros(1000),
'x_right_pix': np.zeros(1000),
'y_right_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
Expand All @@ -154,6 +156,8 @@ def mock_toy(
'y_left_pix': pl.Float64,
'x_right_pix': pl.Float64,
'y_right_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_left_pix', 'y_left_pix', 'x_right_pix', 'y_right_pix']
Expand All @@ -169,6 +173,8 @@ def mock_toy(
'y_right_pix': np.zeros(1000),
'x_avg_pix': np.zeros(1000),
'y_avg_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
Expand All @@ -179,6 +185,8 @@ def mock_toy(
'y_right_pix': pl.Float64,
'x_avg_pix': pl.Float64,
'y_avg_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = [
Expand All @@ -192,12 +200,16 @@ def mock_toy(
'time': np.arange(1000),
'x_left_pix': np.zeros(1000),
'y_left_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_left_pix': pl.Float64,
'y_left_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_left_pix', 'y_left_pix']
Expand All @@ -208,12 +220,16 @@ def mock_toy(
'time': np.arange(1000),
'x_right_pix': np.zeros(1000),
'y_right_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_right_pix': pl.Float64,
'y_right_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_right_pix', 'y_right_pix']
Expand All @@ -224,12 +240,16 @@ def mock_toy(
'time': np.arange(1000),
'x_pix': np.zeros(1000),
'y_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_pix': pl.Float64,
'y_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_pix', 'y_pix']
Expand Down Expand Up @@ -1000,7 +1020,8 @@ def test_detect_events_attribute_error(gaze_dataset_configuration):
},
(
"Column 'position' not found. Available columns are: "
"['time', 'subject_id', 'pixel', 'custom_position', 'velocity']"
"['time', 'trial_id_1', 'trial_id_2', 'subject_id', "
"'pixel', 'custom_position', 'velocity']"
),
id='no_position',
),
Expand All @@ -1012,7 +1033,8 @@ def test_detect_events_attribute_error(gaze_dataset_configuration):
},
(
"Column 'velocity' not found. Available columns are: "
"['time', 'subject_id', 'pixel', 'position', 'custom_velocity']"
"['time', 'trial_id_1', 'trial_id_2', 'subject_id', "
"'pixel', 'position', 'custom_velocity']"
),
id='no_velocity',
),
Expand Down Expand Up @@ -1930,3 +1952,30 @@ def test_load_split_precomputed_events(precomputed_dataset_configuration, by, ex
dataset.load()
dataset.split_precomputed_events(by)
assert len(dataset.precomputed_events) == expected_len


@pytest.mark.parametrize(
    ('by', 'expected_len'),
    [
        pytest.param(
            'trial_id_1',
            40,
            id='split_by_int_column',
        ),
        pytest.param(
            'trial_id_2',
            60,
            id='split_by_str_column',
        ),
        pytest.param(
            ['trial_id_1', 'trial_id_2'],
            80,
            id='split_by_both_columns',
        ),
    ],
)
def test_load_split_gaze(gaze_dataset_configuration, by, expected_len):
    """Splitting loaded gaze data multiplies the number of gaze frames.

    The fixture provides 20 gaze files; each yields 2 groups for
    trial_id_1, 3 groups for trial_id_2, and 4 groups for their
    combination — hence 40 / 60 / 80 expected frames.
    """
    dataset = pm.Dataset(**gaze_dataset_configuration['init_kwargs'])
    dataset.load()
    dataset._split_gaze_data(by)
    assert len(dataset.gaze) == expected_len
Loading