From a0ce2e180257a3a93a5b0e2b3cfc5f3329afb111 Mon Sep 17 00:00:00 2001
From: Iza <38069449+izaskr@users.noreply.github.com>
Date: Wed, 8 Jan 2025 13:23:48 +0100
Subject: [PATCH] feat: split `TextStimulus` by column values (#879)

Co-authored-by: SiQube
Co-authored-by: Daniel G. Krakowczyk
---
 src/pymovements/stimulus/text.py |  32 ++++++++++
 tests/unit/stimulus/text_test.py | 100 ++++++++++++++++++++++++++++++-
 2 files changed, 131 insertions(+), 1 deletion(-)

diff --git a/src/pymovements/stimulus/text.py b/src/pymovements/stimulus/text.py
index 91b75e9a..f98afb92 100644
--- a/src/pymovements/stimulus/text.py
+++ b/src/pymovements/stimulus/text.py
@@ -20,6 +20,7 @@
 """Module for the TextDataFrame."""
 from __future__ import annotations
 
+from collections.abc import Sequence
 from pathlib import Path
 from typing import Any
 
@@ -80,6 +81,37 @@ def __init__(
         self.end_y_column = end_y_column
         self.page_column = page_column
 
+    def split(
+            self,
+            by: str | Sequence[str],
+    ) -> list[TextStimulus]:
+        """Split the AOI dataframe into one TextStimulus per group of ``by`` values.
+
+        Parameters
+        ----------
+        by: str | Sequence[str]
+            Column name(s) passed to ``self.aois.partition_by`` to group the AOI rows.
+
+        Returns
+        -------
+        list[TextStimulus]
+            One TextStimulus per partition, each keeping this stimulus' column settings.
+        """
+        return [
+            TextStimulus(
+                aois=df,
+                aoi_column=self.aoi_column,
+                width_column=self.width_column,
+                height_column=self.height_column,
+                start_x_column=self.start_x_column,
+                start_y_column=self.start_y_column,
+                end_x_column=self.end_x_column,
+                end_y_column=self.end_y_column,
+                page_column=self.page_column,
+            )
+            for df in self.aois.partition_by(by=by, as_dict=False)
+        ]
+
 
 def from_file(
         aoi_path: str | Path,
diff --git a/tests/unit/stimulus/text_test.py b/tests/unit/stimulus/text_test.py
index 48ca93c7..0963f3f7 100644
--- a/tests/unit/stimulus/text_test.py
+++ b/tests/unit/stimulus/text_test.py
@@ -187,7 +187,7 @@
             Path('tests/files/toy_text_1_1_aoi.csv'),
             {'separator': ','},
             EXPECTED_DF,
-            id='toy_text_1_1_aoi',
+            id='toy_text_1_1_aoi_sep',
         ),
     ],
 )
@@ -226,3 +226,101 @@ def test_text_stimulus_unsupported_format():
     expected = 'unsupported file format ".pickle".Supported formats are: '\
         '[\'.csv\', \'.ias\', \'.tsv\', \'.txt\']'
     assert msg == expected
+
+
+@pytest.mark.parametrize(
+    ('aoi_file', 'custom_read_kwargs'),
+    [
+        pytest.param(
+            'tests/files/toy_text_1_1_aoi.csv',
+            None,
+            id='toy_text_1_1_aoi',
+        ),
+        pytest.param(
+            Path('tests/files/toy_text_1_1_aoi.csv'),
+            {'separator': ','},
+            id='toy_text_1_1_aoi_sep',
+        ),
+    ],
+)
+def test_text_stimulus_splitting(aoi_file, custom_read_kwargs):
+    aois_df = pm.stimulus.text.from_file(
+        aoi_file,
+        aoi_column='char',
+        start_x_column='top_left_x',
+        start_y_column='top_left_y',
+        width_column='width',
+        height_column='height',
+        page_column='page',
+        custom_read_kwargs=custom_read_kwargs,
+    )
+
+    aois_df = aois_df.split(by='line_idx')
+    assert len(aois_df) == 2
+
+
+@pytest.mark.parametrize(
+    ('aoi_file', 'custom_read_kwargs'),
+    [
+        pytest.param(
+            'tests/files/toy_text_1_1_aoi.csv',
+            None,
+            id='toy_text_1_1_aoi',
+        ),
+        pytest.param(
+            Path('tests/files/toy_text_1_1_aoi.csv'),
+            {'separator': ','},
+            id='toy_text_1_1_aoi_sep',
+        ),
+    ],
+)
+def test_text_stimulus_splitting_unique_within(aoi_file, custom_read_kwargs):
+    aois_df = pm.stimulus.text.from_file(
+        aoi_file,
+        aoi_column='char',
+        start_x_column='top_left_x',
+        start_y_column='top_left_y',
+        width_column='width',
+        height_column='height',
+        page_column='page',
+        custom_read_kwargs=custom_read_kwargs,
+    )
+
+    aois_df = aois_df.split(by='line_idx')
+    assert all(df.aois.n_unique(subset=['line_idx']) == 1 for df in aois_df)
+
+
+@pytest.mark.parametrize(
+    ('aoi_file', 'custom_read_kwargs'),
+    [
+        pytest.param(
+            'tests/files/toy_text_1_1_aoi.csv',
+            None,
+            id='toy_text_1_1_aoi',
+        ),
+        pytest.param(
+            Path('tests/files/toy_text_1_1_aoi.csv'),
+            {'separator': ','},
+            id='toy_text_1_1_aoi_sep',
+        ),
+    ],
+)
+def test_text_stimulus_splitting_different_between(aoi_file, custom_read_kwargs):
+    aois_df = pm.stimulus.text.from_file(
+        aoi_file,
+        aoi_column='char',
+        start_x_column='top_left_x',
+        start_y_column='top_left_y',
+        width_column='width',
+        height_column='height',
+        page_column='page',
+        custom_read_kwargs=custom_read_kwargs,
+    )
+
+    aois_df = aois_df.split(by='line_idx')
+    unique_values = []
+    for df in aois_df:
+        unique_value = df.aois.unique(subset=['line_idx'])['line_idx'].to_list()
+        unique_values.extend(unique_value)
+
+    assert len(unique_values) == len(set(unique_values))