Skip to content

Commit

Permalink
feat: split TextStimulus by column values (#879)
Browse files Browse the repository at this point in the history
Co-authored-by: SiQube <[email protected]>
Co-authored-by: Daniel G. Krakowczyk <[email protected]>
  • Loading branch information
3 people authored Jan 8, 2025
1 parent 29b806d commit a0ce2e1
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 1 deletion.
32 changes: 32 additions & 0 deletions src/pymovements/stimulus/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""Module for the TextDataFrame."""
from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -80,6 +81,37 @@ def __init__(
self.end_y_column = end_y_column
self.page_column = page_column

def split(
self,
by: str | Sequence[str],
) -> list[TextStimulus]:
"""Split the AOI df.
Parameters
----------
by: str | Sequence[str]
Splitting criteria.
Returns
-------
list[TextStimulus]
A list of TextStimulus objects.
"""
return [
TextStimulus(
aois=df,
aoi_column=self.aoi_column,
width_column=self.width_column,
height_column=self.height_column,
start_x_column=self.start_x_column,
start_y_column=self.start_y_column,
end_x_column=self.end_x_column,
end_y_column=self.end_y_column,
page_column=self.page_column,
)
for df in self.aois.partition_by(by=by, as_dict=False)
]


def from_file(
aoi_path: str | Path,
Expand Down
100 changes: 99 additions & 1 deletion tests/unit/stimulus/text_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
Path('tests/files/toy_text_1_1_aoi.csv'),
{'separator': ','},
EXPECTED_DF,
id='toy_text_1_1_aoi',
id='toy_text_1_1_aoi_sep',
),
],
)
Expand Down Expand Up @@ -226,3 +226,101 @@ def test_text_stimulus_unsupported_format():
expected = 'unsupported file format ".pickle".Supported formats are: '\
'[\'.csv\', \'.ias\', \'.tsv\', \'.txt\']'
assert msg == expected


@pytest.mark.parametrize(
('aoi_file', 'custom_read_kwargs'),
[
pytest.param(
'tests/files/toy_text_1_1_aoi.csv',
None,
id='toy_text_1_1_aoi',
),
pytest.param(
Path('tests/files/toy_text_1_1_aoi.csv'),
{'separator': ','},
id='toy_text_1_1_aoi_sep',
),
],
)
def test_text_stimulus_splitting(aoi_file, custom_read_kwargs):
aois_df = pm.stimulus.text.from_file(
aoi_file,
aoi_column='char',
start_x_column='top_left_x',
start_y_column='top_left_y',
width_column='width',
height_column='height',
page_column='page',
custom_read_kwargs=custom_read_kwargs,
)

aois_df = aois_df.split(by='line_idx')
assert len(aois_df) == 2


@pytest.mark.parametrize(
('aoi_file', 'custom_read_kwargs'),
[
pytest.param(
'tests/files/toy_text_1_1_aoi.csv',
None,
id='toy_text_1_1_aoi',
),
pytest.param(
Path('tests/files/toy_text_1_1_aoi.csv'),
{'separator': ','},
id='toy_text_1_1_aoi_sep',
),
],
)
def test_text_stimulus_splitting_unique_within(aoi_file, custom_read_kwargs):
aois_df = pm.stimulus.text.from_file(
aoi_file,
aoi_column='char',
start_x_column='top_left_x',
start_y_column='top_left_y',
width_column='width',
height_column='height',
page_column='page',
custom_read_kwargs=custom_read_kwargs,
)

aois_df = aois_df.split(by='line_idx')
assert all(df.aois.n_unique(subset=['line_idx']) == 1 for df in aois_df)


@pytest.mark.parametrize(
('aoi_file', 'custom_read_kwargs'),
[
pytest.param(
'tests/files/toy_text_1_1_aoi.csv',
None,
id='toy_text_1_1_aoi',
),
pytest.param(
Path('tests/files/toy_text_1_1_aoi.csv'),
{'separator': ','},
id='toy_text_1_1_aoi_sep',
),
],
)
def test_text_stimulus_splitting_different_between(aoi_file, custom_read_kwargs):
aois_df = pm.stimulus.text.from_file(
aoi_file,
aoi_column='char',
start_x_column='top_left_x',
start_y_column='top_left_y',
width_column='width',
height_column='height',
page_column='page',
custom_read_kwargs=custom_read_kwargs,
)

aois_df = aois_df.split(by='line_idx')
unique_values = []
for df in aois_df:
unique_value = df.aois.unique(subset=['line_idx'])['line_idx'].to_list()
unique_values.extend(unique_value)

assert len(unique_values) == len(set(unique_values))

0 comments on commit a0ce2e1

Please sign in to comment.