diff --git a/src/pymovements/gaze/gaze_dataframe.py b/src/pymovements/gaze/gaze_dataframe.py index 7579c4c55..46e02089f 100644 --- a/src/pymovements/gaze/gaze_dataframe.py +++ b/src/pymovements/gaze/gaze_dataframe.py @@ -49,6 +49,8 @@ class GazeDataFrame: The experiment definition. (default: None) events: pm.EventDataFrame | None A dataframe of events in the gaze signal. (default: None) + auto_column_detect: bool + Flag indicating if the column names should be inferred automatically. (default: False) trial_columns: str | list[str] | None The name of the trial columns in the input data frame. If the list is empty or None, the input data frame is assumed to contain only one trial. If the list is not empty, @@ -180,6 +182,7 @@ def __init__( experiment: Experiment | None = None, events: pm.EventDataFrame | None = None, *, + auto_column_detect: bool = False, trial_columns: str | list[str] | None = None, time_column: str | None = None, time_unit: str | None = 'ms', @@ -226,21 +229,39 @@ def __init__( # The list will be used for inferring n_components. 
column_specifiers: list[list[str]] = [] + component_suffixes = ['x', 'y', 'xl', 'yl', 'xr', 'yr', 'xa', 'ya'] + + if auto_column_detect and pixel_columns is None: + column_candidates = ['pixel_' + suffix for suffix in component_suffixes] + pixel_columns = [c for c in column_candidates if c in self.frame.columns] + if pixel_columns: self._check_component_columns(pixel_columns=pixel_columns) self.nest(pixel_columns, output_column='pixel') column_specifiers.append(pixel_columns) + if auto_column_detect and position_columns is None: + column_candidates = ['position_' + suffix for suffix in component_suffixes] + position_columns = [c for c in column_candidates if c in self.frame.columns] + if position_columns: self._check_component_columns(position_columns=position_columns) self.nest(position_columns, output_column='position') column_specifiers.append(position_columns) + if auto_column_detect and velocity_columns is None: + column_candidates = ['velocity_' + suffix for suffix in component_suffixes] + velocity_columns = [c for c in column_candidates if c in self.frame.columns] + if velocity_columns: self._check_component_columns(velocity_columns=velocity_columns) self.nest(velocity_columns, output_column='velocity') column_specifiers.append(velocity_columns) + if auto_column_detect and acceleration_columns is None: + column_candidates = ['acceleration_' + suffix for suffix in component_suffixes] + acceleration_columns = [c for c in column_candidates if c in self.frame.columns] + if acceleration_columns: self._check_component_columns(acceleration_columns=acceleration_columns) self.nest(acceleration_columns, output_column='acceleration') diff --git a/tests/unit/gaze/gaze_init_test.py b/tests/unit/gaze/gaze_init_test.py index a6cae88a3..3ed210fde 100644 --- a/tests/unit/gaze/gaze_init_test.py +++ b/tests/unit/gaze/gaze_init_test.py @@ -896,6 +896,102 @@ 2, id='df_three_rows_two_position_columns_no_time_1000_hz', ), + + pytest.param( + { + 'data': pl.from_dict( + { + 'time': 
[1, 2, 3], + 'pixel_x': [0., 1., 2.], + 'pixel_y': [3., 4., 5.], + }, + schema={'time': pl.Int64, 'pixel_x': pl.Float64, 'pixel_y': pl.Float64}, + ), + 'auto_column_detect': True, + }, + pl.from_dict( + { + 'time': [1, 2, 3], + 'pixel': [[0., 3.], [1., 4.], [2., 5.]], + }, + schema={'time': pl.Int64, 'pixel': pl.List(pl.Float64)}, + ), + 2, + id='df_auto_columns_pixel', + ), + + pytest.param( + { + 'data': pl.from_dict( + { + 'time': [1, 2, 3], + 'position_x': [0., 1., 2.], + 'position_y': [3., 4., 5.], + }, + schema={'time': pl.Int64, 'position_x': pl.Float64, 'position_y': pl.Float64}, + ), + 'auto_column_detect': True, + }, + pl.from_dict( + { + 'time': [1, 2, 3], + 'position': [[0., 3.], [1., 4.], [2., 5.]], + }, + schema={'time': pl.Int64, 'position': pl.List(pl.Float64)}, + ), + 2, + id='df_auto_columns_position', + ), + + pytest.param( + { + 'data': pl.from_dict( + { + 'time': [1, 2, 3], + 'velocity_x': [0., 1., 2.], + 'velocity_y': [3., 4., 5.], + }, + schema={'time': pl.Int64, 'velocity_x': pl.Float64, 'velocity_y': pl.Float64}, + ), + 'auto_column_detect': True, + }, + pl.from_dict( + { + 'time': [1, 2, 3], + 'velocity': [[0., 3.], [1., 4.], [2., 5.]], + }, + schema={'time': pl.Int64, 'velocity': pl.List(pl.Float64)}, + ), + 2, + id='df_auto_columns_velocity', + ), + + pytest.param( + { + 'data': pl.from_dict( + { + 'time': [1, 2, 3], + 'acceleration_x': [0., 1., 2.], + 'acceleration_y': [3., 4., 5.], + }, + schema={ + 'time': pl.Int64, + 'acceleration_x': pl.Float64, + 'acceleration_y': pl.Float64, + }, + ), + 'auto_column_detect': True, + }, + pl.from_dict( + { + 'time': [1, 2, 3], + 'acceleration': [[0., 3.], [1., 4.], [2., 5.]], + }, + schema={'time': pl.Int64, 'acceleration': pl.List(pl.Float64)}, + ), + 2, + id='df_auto_columns_acceleration', + ), ], ) def test_init_gaze_dataframe_has_expected_attrs(init_kwargs, expected_frame, expected_n_components):