Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: autodetect column names #719

Merged
merged 9 commits into from
Jul 9, 2024
21 changes: 21 additions & 0 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class GazeDataFrame:
The experiment definition. (default: None)
events: pm.EventDataFrame | None
A dataframe of events in the gaze signal. (default: None)
auto_column_detect: bool
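Flag indicating whether the pixel, position, velocity and acceleration column names should be inferred automatically from the input data frame when they are not passed explicitly. (default: False)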
trial_columns: str | list[str] | None
The name of the trial columns in the input data frame. If the list is empty or None,
the input data frame is assumed to contain only one trial. If the list is not empty,
Expand Down Expand Up @@ -180,6 +182,7 @@ def __init__(
experiment: Experiment | None = None,
events: pm.EventDataFrame | None = None,
*,
auto_column_detect: bool = False,
trial_columns: str | list[str] | None = None,
time_column: str | None = None,
time_unit: str | None = 'ms',
Expand Down Expand Up @@ -226,21 +229,39 @@ def __init__(
# The list will be used for inferring n_components.
column_specifiers: list[list[str]] = []

component_suffixes = ['x', 'y', 'xl', 'yl', 'xr', 'yr', 'xa', 'ya']

if auto_column_detect and pixel_columns is None:
column_canditates = ['pixel_' + suffix for suffix in component_suffixes]
pixel_columns = [c for c in column_canditates if c in self.frame.columns]

if pixel_columns:
self._check_component_columns(pixel_columns=pixel_columns)
self.nest(pixel_columns, output_column='pixel')
column_specifiers.append(pixel_columns)

if auto_column_detect and position_columns is None:
column_canditates = ['position_' + suffix for suffix in component_suffixes]
position_columns = [c for c in column_canditates if c in self.frame.columns]

if position_columns:
self._check_component_columns(position_columns=position_columns)
self.nest(position_columns, output_column='position')
column_specifiers.append(position_columns)

if auto_column_detect and velocity_columns is None:
column_canditates = ['velocity_' + suffix for suffix in component_suffixes]
velocity_columns = [c for c in column_canditates if c in self.frame.columns]

if velocity_columns:
self._check_component_columns(velocity_columns=velocity_columns)
self.nest(velocity_columns, output_column='velocity')
column_specifiers.append(velocity_columns)

if auto_column_detect and acceleration_columns is None:
column_canditates = ['acceleration_' + suffix for suffix in component_suffixes]
acceleration_columns = [c for c in column_canditates if c in self.frame.columns]

if acceleration_columns:
self._check_component_columns(acceleration_columns=acceleration_columns)
self.nest(acceleration_columns, output_column='acceleration')
Expand Down
96 changes: 96 additions & 0 deletions tests/unit/gaze/gaze_init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,102 @@
2,
id='df_three_rows_two_position_columns_no_time_1000_hz',
),

pytest.param(
{
'data': pl.from_dict(
{
'time': [1, 2, 3],
'pixel_x': [0., 1., 2.],
'pixel_y': [3., 4., 5.],
},
schema={'time': pl.Int64, 'pixel_x': pl.Float64, 'pixel_y': pl.Float64},
),
'auto_column_detect': True,
},
pl.from_dict(
{
'time': [1, 2, 3],
'pixel': [[0., 3.], [1., 4.], [2., 5.]],
},
schema={'time': pl.Int64, 'pixel': pl.List(pl.Float64)},
),
2,
id='df_auto_columns_pixel',
),

pytest.param(
{
'data': pl.from_dict(
{
'time': [1, 2, 3],
'position_x': [0., 1., 2.],
'position_y': [3., 4., 5.],
},
schema={'time': pl.Int64, 'position_x': pl.Float64, 'position_y': pl.Float64},
),
'auto_column_detect': True,
},
pl.from_dict(
{
'time': [1, 2, 3],
'position': [[0., 3.], [1., 4.], [2., 5.]],
},
schema={'time': pl.Int64, 'position': pl.List(pl.Float64)},
),
2,
id='df_auto_columns_position',
),

pytest.param(
{
'data': pl.from_dict(
{
'time': [1, 2, 3],
'velocity_x': [0., 1., 2.],
'velocity_y': [3., 4., 5.],
},
schema={'time': pl.Int64, 'velocity_x': pl.Float64, 'velocity_y': pl.Float64},
),
'auto_column_detect': True,
},
pl.from_dict(
{
'time': [1, 2, 3],
'velocity': [[0., 3.], [1., 4.], [2., 5.]],
},
schema={'time': pl.Int64, 'velocity': pl.List(pl.Float64)},
),
2,
id='df_auto_columns_velocity',
),

pytest.param(
{
'data': pl.from_dict(
{
'time': [1, 2, 3],
'acceleration_x': [0., 1., 2.],
'acceleration_y': [3., 4., 5.],
},
schema={
'time': pl.Int64,
'acceleration_x': pl.Float64,
'acceleration_y': pl.Float64,
},
),
'auto_column_detect': True,
},
pl.from_dict(
{
'time': [1, 2, 3],
'acceleration': [[0., 3.], [1., 4.], [2., 5.]],
},
schema={'time': pl.Int64, 'acceleration': pl.List(pl.Float64)},
),
2,
id='df_auto_columns_acceleration',
),
],
)
def test_init_gaze_dataframe_has_expected_attrs(init_kwargs, expected_frame, expected_n_components):
Expand Down
Loading