diff --git a/src/pymovements/gaze/gaze_dataframe.py b/src/pymovements/gaze/gaze_dataframe.py index f1e48f58b..a95181eef 100644 --- a/src/pymovements/gaze/gaze_dataframe.py +++ b/src/pymovements/gaze/gaze_dataframe.py @@ -187,55 +187,31 @@ def __init__( if time_column is not None: self.frame = self.frame.rename({time_column: 'time'}) - n_components = None + # List of passed not-None column specifier lists. + # The list will be used for inferring n_components. + column_specifiers: list[list[str]] = [] + if pixel_columns: - _check_component_columns( - frame=self.frame, - pixel_columns=pixel_columns, - ) - self.nest( - input_columns=pixel_columns, - output_column='pixel', - ) - n_components = len(pixel_columns) + _check_component_columns(self.frame, pixel_columns=pixel_columns) + self.nest(pixel_columns, output_column='pixel') + column_specifiers.append(pixel_columns) if position_columns: - _check_component_columns( - frame=self.frame, - position_columns=position_columns, - ) - - self.nest( - input_columns=position_columns, - output_column='position', - ) - n_components = len(position_columns) + _check_component_columns(self.frame, position_columns=position_columns) + self.nest(position_columns, output_column='position') + column_specifiers.append(position_columns) if velocity_columns: - _check_component_columns( - frame=self.frame, - velocity_columns=velocity_columns, - ) - - self.nest( - input_columns=velocity_columns, - output_column='velocity', - ) - n_components = len(velocity_columns) + _check_component_columns(self.frame, velocity_columns=velocity_columns) + self.nest(velocity_columns, output_column='velocity') + column_specifiers.append(velocity_columns) if acceleration_columns: - _check_component_columns( - frame=self.frame, - acceleration_columns=acceleration_columns, - ) + _check_component_columns(self.frame, acceleration_columns=acceleration_columns) + self.nest(acceleration_columns, output_column='acceleration') + column_specifiers.append(acceleration_columns) - self.nest( - input_columns=acceleration_columns, - output_column='acceleration', - ) - n_components = len(acceleration_columns) - - self.n_components = n_components + self.n_components = _infer_n_components(self.frame, column_specifiers) self.experiment = experiment def transform( @@ -397,6 +373,8 @@ def nest( output_column: Name of the resulting tuple column. """ + _check_component_columns(frame=self.frame, **{output_column: input_columns}) + self.frame = self.frame.with_columns( pl.concat_list([pl.col(component) for component in input_columns]) .alias(output_column), @@ -526,3 +504,48 @@ def _check_n_components(n_components: Any) -> None: """Check that n_components is either 2, 4 or 6.""" if n_components not in {2, 4, 6}: raise AttributeError(f'n_components must be either 2, 4 or 6 but is {n_components}') + + +def _infer_n_components(frame: pl.DataFrame, column_specifiers: list[list[str]]) -> int | None: + """Infer number of components from DataFrame. + + Method checks nested columns `pixel`, `position`, `velocity` and `acceleration` for number of + components by getting their list lenghts, which must be equal for all else a ValueError is + raised. Additionally, a list of list of column specifiers is checked for consistency. + + Parameters + ---------- + frame: pl.DataFrame + DataFrame to check. + column_specifiers: + List of list of column specifiers. + + Returns + ------- + int or None + Number of components + + Raises + ------ + ValueError + If number of components is not equal for all considered columns and rows. + """ + all_considered_columns = ['pixel', 'position', 'velocity', 'acceleration'] + considered_columns = [column for column in all_considered_columns if column in frame.columns] + + list_lengths = { + list_length + for column in considered_columns + for list_length in frame.get_column(column).list.lengths().unique().to_list() + } + + for column_specifier_list in column_specifiers: + list_lengths.add(len(column_specifier_list)) + + if len(list_lengths) > 1: + raise ValueError(f'inconsistent number of components inferred: {list_lengths}') + + if len(list_lengths) == 0: + return None + + return next(iter(list_lengths)) diff --git a/tests/gaze/gaze_dataframe_unnest_test.py b/tests/gaze/gaze_dataframe_unnest_test.py index 8060bf768..a21fd7056 100644 --- a/tests/gaze/gaze_dataframe_unnest_test.py +++ b/tests/gaze/gaze_dataframe_unnest_test.py @@ -26,40 +26,39 @@ @pytest.mark.parametrize( - ('init_data', 'unnest_kwargs', 'n_components', 'expected'), + ('init_data', 'unnest_kwargs', 'expected'), [ pytest.param( pl.DataFrame(schema={'pixel': pl.List(pl.Float64)}), {'column': 'pixel', 'output_columns': ['x', 'y']}, - 2, pl.DataFrame(schema={'x': pl.Float64, 'y': pl.Float64}), id='empty_df_with_schema_two_pixel_columns', + marks=pytest.mark.xfail(reason='#522'), ), pytest.param( pl.DataFrame(schema={'abc': pl.Int64, 'pixel': pl.List(pl.Float64)}), {'column': 'pixel', 'output_columns': ['x', 'y']}, - 2, pl.DataFrame(schema={'abc': pl.Int64, 'x': pl.Float64, 'y': pl.Float64}), id='empty_df_with_three_column_schema_two_pixel_columns', + marks=pytest.mark.xfail(reason='#522'), ), pytest.param( pl.DataFrame(schema={'pixel': pl.List(pl.Float64)}), {'column': 'pixel', 'output_columns': ['xl', 'yl', 'xr', 'yr']}, - 4, pl.DataFrame( schema={ 'xl': pl.Float64, 'yl': pl.Float64, 'xr': pl.Float64, 'yr': pl.Float64, }, ), id='empty_df_with_schema_four_pixel_columns', + marks=pytest.mark.xfail(reason='#522'), ), pytest.param( pl.DataFrame(schema={'pixel': pl.List(pl.Float64)}), {'column': 'pixel', 'output_columns': ['xl', 'yl', 'xr', 'yr', 'xa', 'ya']}, - 6, pl.DataFrame( schema={ 'xl': pl.Float64, 'yl': pl.Float64, @@ -68,12 +67,12 @@ }, ), id='empty_df_with_schema_six_pixel_columns', + marks=pytest.mark.xfail(reason='#522'), ), pytest.param( pl.DataFrame({'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_columns': ['x', 'y']}, - 2, pl.DataFrame({'x': [1.23], 'y': [4.56]}), id='df_single_row_two_pixel_columns', ), @@ -81,7 +80,6 @@ pytest.param( pl.DataFrame({'abc': [1], 'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_columns': ['x', 'y']}, - 2, pl.DataFrame({'abc': [1], 'x': [1.23], 'y': [4.56]}), id='df_single_row_three_columns_two_pixel_columns', ), @@ -89,7 +87,6 @@ pytest.param( pl.DataFrame({'pixel': [[1.2, 3.4, 5.6, 7.8]]}), {'column': 'pixel', 'output_columns': ['xl', 'yl', 'xr', 'yr']}, - 4, pl.DataFrame({'xl': [1.2], 'yl': [3.4], 'xr': [5.6], 'yr': [7.8]}), id='df_single_row_four_pixel_columns', ), @@ -97,36 +94,24 @@ pytest.param( pl.DataFrame({'pixel': [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]}), {'column': 'pixel', 'output_columns': ['xl', 'yl', 'xr', 'yr', 'xa', 'ya']}, - 6, pl.DataFrame({'xl': [.1], 'yl': [.2], 'xr': [.3], 'yr': [.4], 'xa': [.5], 'ya': [.6]}), id='df_single_row_six_pixel_columns', ), - ], -) -def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_components, expected): - gaze = pm.GazeDataFrame(init_data) - gaze.n_components = n_components - gaze.unnest(**unnest_kwargs) - assert_frame_equal(gaze.frame, expected) - -@pytest.mark.parametrize( - ('init_data', 'unnest_kwargs', 'n_components', 'expected'), - [ pytest.param( pl.DataFrame(schema={'pixel': pl.List(pl.Float64)}), {'column': 'pixel', 'output_suffixes': ['_x', '_y'], 'output_columns': None}, - 2, pl.DataFrame(schema={'pixel_x': pl.Float64, 'pixel_y': pl.Float64}), id='empty_df_with_schema_two_pixel_suffixes_columns_none', + marks=pytest.mark.xfail(reason='#522'), ), pytest.param( pl.DataFrame(schema={'abc': pl.Int64, 'pixel': pl.List(pl.Float64)}), {'column': 'pixel', 'output_suffixes': ['_x', '_y'], 'output_columns': None}, - 2, pl.DataFrame(schema={'abc': pl.Int64, 'pixel_x': pl.Float64, 'pixel_y': pl.Float64}), id='empty_df_with_three_column_schema_two_pixel_suffixes_columns_none', + marks=pytest.mark.xfail(reason='#522'), ), pytest.param( @@ -136,7 +121,6 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co '_xl', '_yl', '_xr', '_yr', ], 'output_columns': None, }, - 4, pl.DataFrame( schema={ 'pixel_xl': pl.Float64, 'pixel_yl': pl.Float64, @@ -144,12 +128,12 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co }, ), id='empty_df_with_schema_four_pixel_suffixes_columns_none', + marks=pytest.mark.xfail(reason='#522'), ), pytest.param( pl.DataFrame(schema={'pixel': pl.List(pl.Float64)}), {'column': 'pixel', 'output_suffixes': ['_xl', '_yl', '_xr', '_yr', '_xa', '_ya']}, - 6, pl.DataFrame( schema={ 'pixel_xl': pl.Float64, 'pixel_yl': pl.Float64, @@ -158,12 +142,12 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co }, ), id='empty_df_with_schema_six_pixel_columns', + marks=pytest.mark.xfail(reason='#522'), ), pytest.param( pl.DataFrame({'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_suffixes': ['_x', '_y']}, - 2, pl.DataFrame({'pixel_x': [1.23], 'pixel_y': [4.56]}), id='df_single_row_two_pixel_suffixes', ), @@ -171,7 +155,6 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co pytest.param( pl.DataFrame({'abc': [1], 'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_suffixes': ['_x', '_y']}, - 2, pl.DataFrame({'abc': [1], 'pixel_x': [1.23], 'pixel_y': [4.56]}), id='df_single_row_three_columns_two_pixel_suffixes', ), @@ -179,7 +162,6 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co pytest.param( pl.DataFrame({'pixel': [[1.2, 3.4, 5.6, 7.8]]}), {'column': 'pixel', 'output_suffixes': ['_xl', '_yl', '_xr', '_yr']}, - 4, pl.DataFrame({ 'pixel_xl': [1.2], 'pixel_yl': [3.4], 'pixel_xr': [5.6], 'pixel_yr': [7.8], @@ -190,7 +172,6 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co pytest.param( pl.DataFrame({'pixel': [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]}), {'column': 'pixel', 'output_suffixes': ['_xl', '_yl', '_xr', '_yr', '_xa', '_ya']}, - 6, pl.DataFrame({ 'pixel_xl': [.1], 'pixel_yl': [.2], 'pixel_xr': [.3], 'pixel_yr': [.4], @@ -202,7 +183,6 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co pytest.param( pl.DataFrame({'pixel': [[1.23, 4.56]]}), {'column': 'pixel'}, - 2, pl.DataFrame({'pixel_x': [1.23], 'pixel_y': [4.56]}), id='df_single_row_two_pixel_suffixes_default_values', ), @@ -210,7 +190,6 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co pytest.param( pl.DataFrame({'abc': [1], 'pixel': [[1.23, 4.56]]}), {'column': 'pixel'}, - 2, pl.DataFrame({'abc': [1], 'pixel_x': [1.23], 'pixel_y': [4.56]}), id='df_single_row_three_columns_two_pixel_suffixes_default_values', ), @@ -218,7 +197,6 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co pytest.param( pl.DataFrame({'pixel': [[1.2, 3.4, 5.6, 7.8]]}), {'column': 'pixel'}, - 4, pl.DataFrame({ 'pixel_xl': [1.2], 'pixel_yl': [3.4], 'pixel_xr': [5.6], 'pixel_yr': [7.8], @@ -229,7 +207,6 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co pytest.param( pl.DataFrame({'pixel': [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]}), {'column': 'pixel'}, - 6, pl.DataFrame({ 'pixel_xl': [.1], 'pixel_yl': [.2], 'pixel_xr': [.3], 'pixel_yr': [.4], @@ -239,20 +216,18 @@ def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, n_co ), ], ) -def test_gaze_dataframe_unnest_suffixes(init_data, unnest_kwargs, n_components, expected): +def test_gaze_dataframe_unnest_has_expected_frame(init_data, unnest_kwargs, expected): gaze = pm.GazeDataFrame(init_data) - gaze.n_components = n_components gaze.unnest(**unnest_kwargs) assert_frame_equal(gaze.frame, expected) @pytest.mark.parametrize( - ('init_data', 'unnest_kwargs', 'n_components', 'exception', 'exception_msg'), + ('init_data', 'unnest_kwargs', 'exception', 'exception_msg'), [ pytest.param( pl.DataFrame({'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_suffixes': ['_x', '_y', '_z']}, - 2, ValueError, 'Number of output columns / suffixes (3) must match number of components (2)', id='df_single_row_two_pixel_components_three_output_suffixes', @@ -260,7 +235,6 @@ def test_gaze_dataframe_unnest_suffixes(init_data, unnest_kwargs, n_components, pytest.param( pl.DataFrame({'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_columns': ['x']}, - 2, ValueError, 'Number of output columns / suffixes (1) must match number of components (2)', id='df_single_row_two_pixel_components_one_output_column', @@ -268,7 +242,6 @@ def test_gaze_dataframe_unnest_suffixes(init_data, unnest_kwargs, n_components, pytest.param( pl.DataFrame({'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_suffixes': ['_x', '_x']}, - 2, ValueError, 'Output columns / suffixes must be unique', id='df_single_row_two_output_suffixes_non_unique', @@ -276,7 +249,6 @@ def test_gaze_dataframe_unnest_suffixes(init_data, unnest_kwargs, n_components, pytest.param( pl.DataFrame({'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_columns': ['x', 'x']}, - 2, ValueError, 'Output columns / suffixes must be unique', id='df_single_row_two_output_columns_non_unique', @@ -284,12 +256,25 @@ def test_gaze_dataframe_unnest_suffixes(init_data, unnest_kwargs, n_components, pytest.param( pl.DataFrame({'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_suffixes': ['_x', '_y'], 'output_columns': ['x', 'y']}, - 2, ValueError, 'The arguments "output_columns" and "output_suffixes" are mutually exclusive.', id='df_single_row_two_output_columns_and_suffixes', ), - # invalid number of components + ], + +) +def test_gaze_dataframe_unnest_errors(init_data, unnest_kwargs, exception, exception_msg): + with pytest.raises(exception) as exc_info: + gaze = pm.GazeDataFrame(init_data) + gaze.unnest(**unnest_kwargs) + + msg, = exc_info.value.args + assert msg == exception_msg + + +@pytest.mark.parametrize( + ('init_data', 'unnest_kwargs', 'n_components', 'exception', 'exception_msg'), + [ pytest.param( pl.DataFrame({'pixel': [[1.23, 4.56]]}), {'column': 'pixel', 'output_suffixes': ['_x', '_y']}, @@ -301,7 +286,7 @@ def test_gaze_dataframe_unnest_suffixes(init_data, unnest_kwargs, n_components, ], ) -def test_gaze_dataframe_unnest_errors( +def test_gaze_dataframe_unnest_invalid_number_of_components( init_data, unnest_kwargs, n_components, exception, exception_msg, ): with pytest.raises(exception) as exc_info: diff --git a/tests/gaze/gaze_init_test.py b/tests/gaze/gaze_init_test.py index efbb5939b..f74b87c89 100644 --- a/tests/gaze/gaze_init_test.py +++ b/tests/gaze/gaze_init_test.py @@ -18,6 +18,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. """Test GazeDataFrame initialization.""" +import numpy as np import polars as pl import pytest from polars.testing import assert_frame_equal @@ -26,13 +27,14 @@ @pytest.mark.parametrize( - ('init_kwargs', 'expected_frame'), + ('init_kwargs', 'expected_frame', 'expected_n_components'), [ pytest.param( { 'data': pl.DataFrame(), }, pl.DataFrame(), + None, id='empty_df_no_schema', ), @@ -41,6 +43,7 @@ 'data': pl.DataFrame(schema={'abc': pl.Int64}), }, pl.DataFrame(schema={'abc': pl.Int64}), + None, id='empty_df_with_schema_no_component_columns', ), @@ -53,6 +56,7 @@ 'acceleration_columns': [], }, pl.DataFrame(schema={'abc': pl.Int64}), + None, id='empty_df_with_schema_all_component_columns_empty_lists', ), @@ -62,6 +66,7 @@ 'pixel_columns': ['x', 'y'], }, pl.DataFrame(schema={'pixel': pl.List(pl.Float64)}), + 2, id='empty_df_with_schema_two_pixel_columns', ), @@ -71,6 +76,7 @@ 'pixel_columns': ['x', 'y'], }, pl.DataFrame(schema={'abc': pl.Int64, 'pixel': pl.List(pl.Float64)}), + 2, id='empty_df_with_three_column_schema_two_pixel_columns', ), @@ -84,6 +90,7 @@ 'pixel_columns': ['xr', 'yr', 'xl', 'yl'], }, pl.DataFrame(schema={'pixel': pl.List(pl.Float64)}), + 4, id='empty_df_with_schema_four_pixel_columns', ), @@ -101,6 +108,7 @@ ], }, pl.DataFrame(schema={'pixel': pl.List(pl.Float64)}), + 6, id='empty_df_with_schema_six_pixel_columns', ), @@ -115,6 +123,7 @@ {'pixel': [[1.23, 4.56]]}, schema={'pixel': pl.List(pl.Float64)}, ), + 2, id='df_single_row_two_pixel_columns', ), @@ -130,6 +139,7 @@ {'abc': [1], 'pixel': [[1.23, 4.56]]}, schema={'abc': pl.Int64, 'pixel': pl.List(pl.Float64)}, ), + 2, id='df_single_row_three_columns_two_pixel_columns', ), @@ -145,6 +155,7 @@ {'pixel': [[1.2, 3.4, 5.6, 7.8]]}, schema={'pixel': pl.List(pl.Float64)}, ), + 4, id='df_single_row_four_pixel_columns', ), @@ -170,6 +181,7 @@ {'pixel': [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]}, schema={'pixel': pl.List(pl.Float64)}, ), + 6, id='df_single_row_six_pixel_columns', ), @@ -179,6 +191,7 @@ 'position_columns': ['x', 'y'], }, pl.DataFrame(schema={'position': pl.List(pl.Float64)}), + 2, id='empty_df_with_schema_two_position_columns', ), @@ -188,6 +201,7 @@ 'position_columns': ['x', 'y'], }, pl.DataFrame(schema={'abc': pl.Int64, 'position': pl.List(pl.Float64)}), + 2, id='empty_df_with_three_column_schema_two_position_columns', ), @@ -201,6 +215,7 @@ 'position_columns': ['xr', 'yr', 'xl', 'yl'], }, pl.DataFrame(schema={'position': pl.List(pl.Float64)}), + 4, id='empty_df_with_schema_four_position_columns', ), @@ -218,6 +233,7 @@ ], }, pl.DataFrame(schema={'position': pl.List(pl.Float64)}), + 6, id='empty_df_with_schema_six_position_columns', ), @@ -232,6 +248,7 @@ {'position': [[1.23, 4.56]]}, schema={'position': pl.List(pl.Float64)}, ), + 2, id='df_single_row_two_position_columns', ), @@ -247,6 +264,7 @@ {'abc': [1], 'position': [[1.23, 4.56]]}, schema={'abc': pl.Int64, 'position': pl.List(pl.Float64)}, ), + 2, id='df_single_row_three_columns_two_position_columns', ), @@ -262,6 +280,7 @@ {'position': [[1.2, 3.4, 5.6, 7.8]]}, schema={'position': pl.List(pl.Float64)}, ), + 4, id='df_single_row_four_position_columns', ), @@ -287,6 +306,7 @@ {'position': [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]}, schema={'position': pl.List(pl.Float64)}, ), + 6, id='df_single_row_six_position_columns', ), @@ -296,6 +316,7 @@ 'velocity_columns': ['x_vel', 'y_vel'], }, pl.DataFrame(schema={'velocity': pl.List(pl.Float64)}), + 2, id='empty_df_with_schema_two_velocity_columns', ), @@ -309,6 +330,7 @@ 'velocity_columns': ['x_vel', 'y_vel'], }, pl.DataFrame(schema={'abc': pl.Int64, 'velocity': pl.List(pl.Float64)}), + 2, id='empty_df_with_three_column_schema_two_velocity_columns', ), @@ -323,6 +345,7 @@ 'velocity_columns': ['xr_vel', 'yr_vel', 'xl_vel', 'yl_vel'], }, pl.DataFrame(schema={'velocity': pl.List(pl.Float64)}), + 4, id='empty_df_with_schema_four_velocity_columns', ), @@ -342,6 +365,7 @@ ], }, pl.DataFrame(schema={'velocity': pl.List(pl.Float64)}), + 6, id='empty_df_with_schema_six_velocity_columns', ), @@ -357,6 +381,7 @@ {'velocity': [[1.23, 4.56]]}, schema={'velocity': pl.List(pl.Float64)}, ), + 2, id='df_single_row_two_velocity_columns', ), @@ -372,6 +397,7 @@ {'abc': [1], 'velocity': [[1.23, 4.56]]}, schema={'abc': pl.Int64, 'velocity': pl.List(pl.Float64)}, ), + 2, id='df_single_row_three_columns_two_velocity_columns', ), @@ -393,6 +419,7 @@ {'velocity': [[1.2, 3.4, 5.6, 7.8]]}, schema={'velocity': pl.List(pl.Float64)}, ), + 4, id='df_single_row_four_velocity_columns', ), @@ -420,6 +447,7 @@ {'velocity': [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]}, schema={'velocity': pl.List(pl.Float64)}, ), + 6, id='df_single_row_six_velocity_columns', ), @@ -429,6 +457,7 @@ 'acceleration_columns': ['x_acc', 'y_acc'], }, pl.DataFrame(schema={'acceleration': pl.List(pl.Float64)}), + 2, id='empty_df_with_schema_two_acceleration_columns', ), @@ -442,6 +471,7 @@ 'acceleration_columns': ['x_acc', 'y_acc'], }, pl.DataFrame(schema={'abc': pl.Int64, 'acceleration': pl.List(pl.Float64)}), + 2, id='empty_df_with_three_column_schema_two_acceleration_columns', ), @@ -456,6 +486,7 @@ 'acceleration_columns': ['xr_acc', 'yr_acc', 'xl_acc', 'yl_acc'], }, pl.DataFrame(schema={'acceleration': pl.List(pl.Float64)}), + 4, id='empty_df_with_schema_four_acceleration_columns', ), @@ -475,6 +506,7 @@ ], }, pl.DataFrame(schema={'acceleration': pl.List(pl.Float64)}), + 6, id='empty_df_with_schema_six_acceleration_columns', ), @@ -490,6 +522,7 @@ {'acceleration': [[1.23, 4.56]]}, schema={'acceleration': pl.List(pl.Float64)}, ), + 2, id='df_single_row_two_acceleration_columns', ), @@ -505,6 +538,7 @@ {'abc': [1], 'acceleration': [[1.23, 4.56]]}, schema={'abc': pl.Int64, 'acceleration': pl.List(pl.Float64)}, ), + 2, id='df_single_row_three_columns_two_acceleration_columns', ), @@ -526,6 +560,7 @@ {'acceleration': [[1.2, 3.4, 5.6, 7.8]]}, schema={'acceleration': pl.List(pl.Float64)}, ), + 4, id='df_single_row_four_acceleration_columns', ), @@ -553,6 +588,7 @@ {'acceleration': [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]}, schema={'acceleration': pl.List(pl.Float64)}, ), + 6, id='df_single_row_six_acceleration_columns', ), @@ -591,6 +627,7 @@ 'acceleration': pl.List(pl.Float64), }, ), + 2, id='df_single_row_all_types_two_columns', ), @@ -661,13 +698,15 @@ 'acceleration': pl.List(pl.Float64), }, ), + 6, id='df_single_row_all_types_six_columns', ), ], ) -def test_init_gaze_dataframe_has_expected_frame(init_kwargs, expected_frame): +def test_init_gaze_dataframe_has_expected_attrs(init_kwargs, expected_frame, expected_n_components): gaze = GazeDataFrame(**init_kwargs) assert_frame_equal(gaze.frame, expected_frame) + assert gaze.n_components == expected_n_components @pytest.mark.parametrize( @@ -1114,11 +1153,43 @@ def test_init_gaze_dataframe_has_expected_frame(init_kwargs, expected_frame): 'column y_acc from acceleration_columns is not available in dataframe', id='acceleration_columns_missing_column', ), + + pytest.param( + { + 'data': pl.DataFrame( + schema={ + 'x': pl.Float64, 'y': pl.Float64, + 'xr': pl.Float64, 'yr': pl.Float64, + 'xl': pl.Float64, 'yl': pl.Float64, + }, + ), + 'pixel_columns': ['x', 'y'], + 'position_columns': ['xl', 'yl', 'xr', 'yr'], + }, + ValueError, + 'inconsistent number of components inferred: {2, 4}', + id='inconsistent_number_of_components', + ), + ], ) -def test_event_dataframe_init_exceptions(init_kwargs, exception, exception_msg): +def test_gaze_dataframe_init_exceptions(init_kwargs, exception, exception_msg): with pytest.raises(exception) as excinfo: GazeDataFrame(**init_kwargs) msg, = excinfo.value.args assert msg == exception_msg + + +def test_gaze_copy_init_has_same_n_components(): + """Tests if gaze initialization with frame with nested columns has correct n_components. + + Refers to issue #514. + """ + df_orig = pl.from_numpy(np.zeros((2, 1000)), orient='col', schema=['x', 'y']) + gaze = GazeDataFrame(df_orig, position_columns=['x', 'y']) + + df_copy = gaze.frame.clone() + gaze_copy = GazeDataFrame(df_copy) + + assert gaze.n_components == gaze_copy.n_components diff --git a/tests/gaze/gaze_transform_test.py b/tests/gaze/gaze_transform_test.py index 66c7d24d8..24fea7b24 100644 --- a/tests/gaze/gaze_transform_test.py +++ b/tests/gaze/gaze_transform_test.py @@ -522,7 +522,7 @@ def test_gaze_dataframe_pix2deg_creates_position_column(data, experiment, pixel_ ), pytest.param( { - 'data': pl.DataFrame(schema={'x': pl.Float64, 'y': pl.Float64}), + 'data': pl.from_dict({'x': [0.1], 'y': [0.2]}), 'experiment': pm.Experiment(1024, 768, 38, 30, 60, 'center', 1000), 'acceleration_columns': ['x', 'y'], }, @@ -603,7 +603,7 @@ def test_gaze_dataframe_pos2acc_creates_acceleration_column(data, experiment, po ), pytest.param( { - 'data': pl.DataFrame(schema={'x': pl.Float64, 'y': pl.Float64}), + 'data': pl.from_dict({'x': [0.1], 'y': [0.2]}), 'experiment': pm.Experiment(1024, 768, 38, 30, 60, 'center', 1000), 'pixel_columns': ['x', 'y'], }, @@ -684,7 +684,7 @@ def test_gaze_dataframe_pos2vel_creates_velocity_column(data, experiment, positi ), pytest.param( { - 'data': pl.DataFrame(schema={'x': pl.Float64, 'y': pl.Float64}), + 'data': pl.from_dict({'x': [0.1], 'y': [0.2]}), 'experiment': pm.Experiment(1024, 768, 38, 30, 60, 'center', 1000), 'pixel_columns': ['x', 'y'], },