From 1e2eac18c5bd53f5a9ca2c3d679adc3eb4613964 Mon Sep 17 00:00:00 2001 From: prassepaul Date: Tue, 9 Jul 2024 08:18:16 +0200 Subject: [PATCH] feat: add autodetect of column names (#719) --- src/pymovements/gaze/gaze_dataframe.py | 21 ++++++ tests/unit/gaze/gaze_init_test.py | 96 ++++++++++++++++++++++++++ 2 files changed, 117 insertions(+) diff --git a/src/pymovements/gaze/gaze_dataframe.py b/src/pymovements/gaze/gaze_dataframe.py index 7579c4c55..46e02089f 100644 --- a/src/pymovements/gaze/gaze_dataframe.py +++ b/src/pymovements/gaze/gaze_dataframe.py @@ -49,6 +49,8 @@ class GazeDataFrame: The experiment definition. (default: None) events: pm.EventDataFrame | None A dataframe of events in the gaze signal. (default: None) + auto_column_detect: bool + Flag indicating if the column names should be inferred automatically. (default: False) trial_columns: str | list[str] | None The name of the trial columns in the input data frame. If the list is empty or None, the input data frame is assumed to contain only one trial. If the list is not empty, @@ -180,6 +182,7 @@ def __init__( experiment: Experiment | None = None, events: pm.EventDataFrame | None = None, *, + auto_column_detect: bool = False, trial_columns: str | list[str] | None = None, time_column: str | None = None, time_unit: str | None = 'ms', @@ -226,21 +229,39 @@ def __init__( # The list will be used for inferring n_components. column_specifiers: list[list[str]] = [] + component_suffixes = ['x', 'y', 'xl', 'yl', 'xr', 'yr', 'xa', 'ya'] + + if auto_column_detect and pixel_columns is None: + column_canditates = ['pixel_' + suffix for suffix in component_suffixes] + pixel_columns = [c for c in column_canditates if c in self.frame.columns] + if pixel_columns: self._check_component_columns(pixel_columns=pixel_columns) self.nest(pixel_columns, output_column='pixel') column_specifiers.append(pixel_columns) + if auto_column_detect and position_columns is None: + column_canditates = ['position_' + suffix for suffix in component_suffixes] + position_columns = [c for c in column_canditates if c in self.frame.columns] + if position_columns: self._check_component_columns(position_columns=position_columns) self.nest(position_columns, output_column='position') column_specifiers.append(position_columns) + if auto_column_detect and velocity_columns is None: + column_canditates = ['velocity_' + suffix for suffix in component_suffixes] + velocity_columns = [c for c in column_canditates if c in self.frame.columns] + if velocity_columns: self._check_component_columns(velocity_columns=velocity_columns) self.nest(velocity_columns, output_column='velocity') column_specifiers.append(velocity_columns) + if auto_column_detect and acceleration_columns is None: + column_canditates = ['acceleration_' + suffix for suffix in component_suffixes] + acceleration_columns = [c for c in column_canditates if c in self.frame.columns] + if acceleration_columns: self._check_component_columns(acceleration_columns=acceleration_columns) self.nest(acceleration_columns, output_column='acceleration') diff --git a/tests/unit/gaze/gaze_init_test.py b/tests/unit/gaze/gaze_init_test.py index a6cae88a3..3ed210fde 100644 --- a/tests/unit/gaze/gaze_init_test.py +++ b/tests/unit/gaze/gaze_init_test.py @@ -896,6 +896,102 @@ 2, id='df_three_rows_two_position_columns_no_time_1000_hz', ), + + pytest.param( + { + 'data': pl.from_dict( + { + 'time': [1, 2, 3], + 'pixel_x': [0., 1., 2.], + 'pixel_y': [3., 4., 5.], + }, + schema={'time': pl.Int64, 'pixel_x': pl.Float64, 'pixel_y': pl.Float64}, + ), + 'auto_column_detect': True, + }, + pl.from_dict( + { + 'time': [1, 2, 3], + 'pixel': [[0., 3.], [1., 4.], [2., 5.]], + }, + schema={'time': pl.Int64, 'pixel': pl.List(pl.Float64)}, + ), + 2, + id='df_auto_columns_pixel', + ), + + pytest.param( + { + 'data': pl.from_dict( + { + 'time': [1, 2, 3], + 'position_x': [0., 1., 2.], + 'position_y': [3., 4., 5.], + }, + schema={'time': pl.Int64, 'position_x': pl.Float64, 'position_y': pl.Float64}, + ), + 'auto_column_detect': True, + }, + pl.from_dict( + { + 'time': [1, 2, 3], + 'position': [[0., 3.], [1., 4.], [2., 5.]], + }, + schema={'time': pl.Int64, 'position': pl.List(pl.Float64)}, + ), + 2, + id='df_auto_columns_position', + ), + + pytest.param( + { + 'data': pl.from_dict( + { + 'time': [1, 2, 3], + 'velocity_x': [0., 1., 2.], + 'velocity_y': [3., 4., 5.], + }, + schema={'time': pl.Int64, 'velocity_x': pl.Float64, 'velocity_y': pl.Float64}, + ), + 'auto_column_detect': True, + }, + pl.from_dict( + { + 'time': [1, 2, 3], + 'velocity': [[0., 3.], [1., 4.], [2., 5.]], + }, + schema={'time': pl.Int64, 'velocity': pl.List(pl.Float64)}, + ), + 2, + id='df_auto_columns_velocity', + ), + + pytest.param( + { + 'data': pl.from_dict( + { + 'time': [1, 2, 3], + 'acceleration_x': [0., 1., 2.], + 'acceleration_y': [3., 4., 5.], + }, + schema={ + 'time': pl.Int64, + 'acceleration_x': pl.Float64, + 'acceleration_y': pl.Float64, + }, + ), + 'auto_column_detect': True, + }, + pl.from_dict( + { + 'time': [1, 2, 3], + 'acceleration': [[0., 3.], [1., 4.], [2., 5.]], + }, + schema={'time': pl.Int64, 'acceleration': pl.List(pl.Float64)}, + ), + 2, + id='df_auto_columns_acceleration', + ), ], ) def test_init_gaze_dataframe_has_expected_attrs(init_kwargs, expected_frame, expected_n_components):