Skip to content

Commit

Permalink
feat: add autodetect of column names (#719)
Browse files Browse the repository at this point in the history
  • Loading branch information
prassepaul authored Jul 9, 2024
1 parent 7c5180e commit 1e2eac1
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class GazeDataFrame:
The experiment definition. (default: None)
events: pm.EventDataFrame | None
A dataframe of events in the gaze signal. (default: None)
auto_column_detect: bool
Flag indicating if the column names should be inferred automatically. (default: False)
trial_columns: str | list[str] | None
The name of the trial columns in the input data frame. If the list is empty or None,
the input data frame is assumed to contain only one trial. If the list is not empty,
Expand Down Expand Up @@ -180,6 +182,7 @@ def __init__(
experiment: Experiment | None = None,
events: pm.EventDataFrame | None = None,
*,
auto_column_detect: bool = False,
trial_columns: str | list[str] | None = None,
time_column: str | None = None,
time_unit: str | None = 'ms',
Expand Down Expand Up @@ -226,21 +229,39 @@ def __init__(
# The list will be used for inferring n_components.
column_specifiers: list[list[str]] = []

component_suffixes = ['x', 'y', 'xl', 'yl', 'xr', 'yr', 'xa', 'ya']

if auto_column_detect and pixel_columns is None:
column_canditates = ['pixel_' + suffix for suffix in component_suffixes]
pixel_columns = [c for c in column_canditates if c in self.frame.columns]

if pixel_columns:
self._check_component_columns(pixel_columns=pixel_columns)
self.nest(pixel_columns, output_column='pixel')
column_specifiers.append(pixel_columns)

if auto_column_detect and position_columns is None:
column_canditates = ['position_' + suffix for suffix in component_suffixes]
position_columns = [c for c in column_canditates if c in self.frame.columns]

if position_columns:
self._check_component_columns(position_columns=position_columns)
self.nest(position_columns, output_column='position')
column_specifiers.append(position_columns)

if auto_column_detect and velocity_columns is None:
column_canditates = ['velocity_' + suffix for suffix in component_suffixes]
velocity_columns = [c for c in column_canditates if c in self.frame.columns]

if velocity_columns:
self._check_component_columns(velocity_columns=velocity_columns)
self.nest(velocity_columns, output_column='velocity')
column_specifiers.append(velocity_columns)

if auto_column_detect and acceleration_columns is None:
column_canditates = ['acceleration_' + suffix for suffix in component_suffixes]
acceleration_columns = [c for c in column_canditates if c in self.frame.columns]

if acceleration_columns:
self._check_component_columns(acceleration_columns=acceleration_columns)
self.nest(acceleration_columns, output_column='acceleration')
Expand Down
96 changes: 96 additions & 0 deletions tests/unit/gaze/gaze_init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,102 @@
2,
id='df_three_rows_two_position_columns_no_time_1000_hz',
),
pytest.param(
{
'data': pl.from_dict(
{
'time': [1, 2, 3],
'pixel_x': [0., 1., 2.],
'pixel_y': [3., 4., 5.],
},
schema={'time': pl.Int64, 'pixel_x': pl.Float64, 'pixel_y': pl.Float64},
),
'auto_column_detect': True,
},
pl.from_dict(
{
'time': [1, 2, 3],
'pixel': [[0., 3.], [1., 4.], [2., 5.]],
},
schema={'time': pl.Int64, 'pixel': pl.List(pl.Float64)},
),
2,
id='df_auto_columns_pixel',
),
pytest.param(
{
'data': pl.from_dict(
{
'time': [1, 2, 3],
'position_x': [0., 1., 2.],
'position_y': [3., 4., 5.],
},
schema={'time': pl.Int64, 'position_x': pl.Float64, 'position_y': pl.Float64},
),
'auto_column_detect': True,
},
pl.from_dict(
{
'time': [1, 2, 3],
'position': [[0., 3.], [1., 4.], [2., 5.]],
},
schema={'time': pl.Int64, 'position': pl.List(pl.Float64)},
),
2,
id='df_auto_columns_position',
),
pytest.param(
{
'data': pl.from_dict(
{
'time': [1, 2, 3],
'velocity_x': [0., 1., 2.],
'velocity_y': [3., 4., 5.],
},
schema={'time': pl.Int64, 'velocity_x': pl.Float64, 'velocity_y': pl.Float64},
),
'auto_column_detect': True,
},
pl.from_dict(
{
'time': [1, 2, 3],
'velocity': [[0., 3.], [1., 4.], [2., 5.]],
},
schema={'time': pl.Int64, 'velocity': pl.List(pl.Float64)},
),
2,
id='df_auto_columns_velocity',
),
pytest.param(
{
'data': pl.from_dict(
{
'time': [1, 2, 3],
'acceleration_x': [0., 1., 2.],
'acceleration_y': [3., 4., 5.],
},
schema={
'time': pl.Int64,
'acceleration_x': pl.Float64,
'acceleration_y': pl.Float64,
},
),
'auto_column_detect': True,
},
pl.from_dict(
{
'time': [1, 2, 3],
'acceleration': [[0., 3.], [1., 4.], [2., 5.]],
},
schema={'time': pl.Int64, 'acceleration': pl.List(pl.Float64)},
),
2,
id='df_auto_columns_acceleration',
),
],
)
def test_init_gaze_dataframe_has_expected_attrs(init_kwargs, expected_frame, expected_n_components):
Expand Down

0 comments on commit 1e2eac1

Please sign in to comment.