diff --git a/src/pymovements/gaze/gaze_dataframe.py b/src/pymovements/gaze/gaze_dataframe.py index 36f0f20f3..598369712 100644 --- a/src/pymovements/gaze/gaze_dataframe.py +++ b/src/pymovements/gaze/gaze_dataframe.py @@ -405,6 +405,52 @@ def pos2vel( """ self.transform('pos2vel', method=method, **kwargs) + def smooth( + self, + method: str = 'savitzky_golay', + window_length: int = 7, + degree: int = 2, + column: str = 'position', + padding: str | float | int | None = 'nearest', + **kwargs: int | float | str, + ) -> None: + """Smooth data in a column. + + Parameters + ---------- + method: + The method to use for smoothing. Choose from ``savitzky_golay``, ``moving_average``, + ``exponential_moving_average``. See :func:`~transforms.smooth()` for details. + window_length: + For ``moving_average`` this is the window size to calculate the mean of the subsequent + samples. For ``savitzky_golay`` this is the window size to use for the polynomial fit. + For ``exponential_moving_average`` this is the span parameter. + degree: + The degree of the polynomial to use. This has only an effect if using + ``savitzky_golay`` as smoothing method. `degree` must be less than `window_length`. + column: + The input column name to which the smoothing is applied. + padding: + Must be either ``None``, a scalar or one of the strings + ``mirror``, ``nearest`` or ``wrap``. + This determines the type of extension to use for the padded signal to + which the filter is applied. + When passing ``None``, no extension padding is used. + When passing a scalar value, data will be padded using the passed value. + See :func:`~transforms.smooth()` for details on the padding methods. + **kwargs: + Additional keyword arguments to be passed to the :func:`~transforms.smooth()` method. + """ + self.transform( + 'smooth', + column=column, + method=method, + degree=degree, + window_length=window_length, + padding=padding, + **kwargs, + ) + def detect( self, method: Callable[..., pm.EventDataFrame] | str, diff --git a/src/pymovements/gaze/transforms.py b/src/pymovements/gaze/transforms.py index 915b61080..3713367df 100644 --- a/src/pymovements/gaze/transforms.py +++ b/src/pymovements/gaze/transforms.py @@ -31,7 +31,6 @@ from pymovements.utils import checks - TransformMethod = TypeVar('TransformMethod', bound=Callable[..., pl.Expr]) @@ -558,6 +557,165 @@ def savitzky_golay( ).alias(output_column) +@register_transform +def smooth( + *, + method: str, + window_length: int, + n_components: int, + degree: int | None = None, + column: str = 'position', + padding: str | float | int | None = 'nearest', +) -> pl.Expr: + """ + Smooth data in a column. + + Parameters + ---------- + method: + The method to use for smoothing. See Notes for more details. + window_length + For ``moving_average`` this is the window size to calculate the mean of the subsequent + samples. For ``savitzky_golay`` this is the window size to use for the polynomial fit. + For ``exponential_moving_average`` this is the span parameter. + n_components: + Number of components in input column. + degree: + The degree of the polynomial to use. This has only an effect if using ``savitzky_golay`` as + smoothing method. `degree` must be less than `window_length`. + column: + The input column name to which the smoothing is applied. + padding: + Must be either ``None``, a scalar or one of the strings ``mirror``, ``nearest`` or ``wrap``. + This determines the type of extension to use for the padded signal to + which the filter is applied. + When passing ``None``, no extension padding is used. + When passing a scalar value, data will be padded using the passed value. + See the Notes for more details on the padding methods. + + Returns + ------- + polars.Expr + The respective polars expression. + + Notes + ----- + There following methods are available for smoothing: + + * ``savitzky_golay``: Smooth data by applying a Savitzky-Golay filter. + See :py:func:`~pymovements.gaze.transforms.savitzky_golay` for further details. + * ``moving_average``: Smooth data by calculating the mean of the subsequent samples. + Each smoothed sample is calculated by the mean of the samples in the window around the sample. + * ``exponential_moving_average``: Smooth data by exponentially weighted moving average. + + Details on the `padding` options: + + * ``None``: No padding extension is used. + * scalar value (int or float): The padding extension contains the specified scalar value. + * ``mirror``: Repeats the values at the edges in reverse order. The value closest to the edge is + not included. + * ``nearest``: The padding extension contains the nearest input value. + * ``wrap``: The padding extension contains the values from the other end of the array. + + Given the input is ``[1, 2, 3, 4, 5, 6, 7, 8]``, and + `window_length` is 7, the following table shows the padded data for + the various ``padding`` options: + + +-------------+-------------+----------------------------+-------------+ + | mode | padding | input | padding | + +=============+=============+============================+=============+ + | ``None`` | ``- - -`` | ``1 2 3 4 5 6 7 8`` | ``- - -`` | + +-------------+-------------+----------------------------+-------------+ + | ``0`` | ``0 0 0`` | ``1 2 3 4 5 6 7 8`` | ``0 0 0`` | + +-------------+-------------+----------------------------+-------------+ + | ``1`` | ``1 1 1`` | ``1 2 3 4 5 6 7 8`` | ``1 1 1`` | + +-------------+-------------+----------------------------+-------------+ + | ``nearest`` | ``1 1 1`` | ``1 2 3 4 5 6 7 8`` | ``8 8 8`` | + +-------------+-------------+----------------------------+-------------+ + | ``mirror`` | ``4 3 2`` | ``1 2 3 4 5 6 7 8`` | ``7 6 5`` | + +-------------+-------------+----------------------------+-------------+ + | ``wrap`` | ``6 7 8`` | ``1 2 3 4 5 6 7 8`` | ``1 2 3`` | + +-------------+-------------+----------------------------+-------------+ + + """ + _check_window_length(window_length=window_length) + _check_padding(padding=padding) + + if method in {'moving_average', 'exponential_moving_average'}: + pad_kwargs: dict[str, Any] = {'pad_width': 0} + pad_func = _identity + + if isinstance(padding, (int, float)): + pad_kwargs['constant_values'] = padding + padding = 'constant' + elif padding == 'nearest': + # option 'nearest' is called 'edge' for np.pad + padding = 'edge' + elif padding == 'mirror': + # option 'mirror' is called 'reflect' for np.pad + padding = 'reflect' + + if padding is not None: + pad_kwargs['mode'] = padding + pad_kwargs['pad_width'] = np.ceil(window_length / 2).astype(int) + + pad_func = partial( + np.pad, + **pad_kwargs, + ) + + if method == 'moving_average': + + return pl.concat_list( + [ + pl.col(column).list.get(component).map(pad_func).list.explode() + .rolling_mean(window_size=window_length, center=True) + .shift(periods=pad_kwargs['pad_width']) + .slice(pad_kwargs['pad_width'] * 2) + for component in range(n_components) + ], + ).alias(column) + + return pl.concat_list( + [ + pl.col(column).list.get(component).map(pad_func).list.explode() + .ewm_mean( + span=window_length, + adjust=False, + min_periods=window_length, + ).shift(periods=pad_kwargs['pad_width']) + .slice(pad_kwargs['pad_width'] * 2) + for component in range(n_components) + ], + ).alias(column) + + if method == 'savitzky_golay': + if degree is None: + raise TypeError("'degree' must not be none for method 'savitzky_golay'") + + return savitzky_golay( + window_length=window_length, + degree=degree, + sampling_rate=1, + padding=padding, + derivative=0, + n_components=n_components, + input_column=column, + output_column=None, + ) + + supported_methods = ['moving_average', 'exponential_moving_average', 'savitzky_golay'] + + raise ValueError( + f"Unknown method '{method}'. Supported methods are: {supported_methods}", + ) + + +def _identity(x: Any) -> Any: + """Identity function as placeholder for None as padding.""" + return x + + def _check_window_length(window_length: Any) -> None: """Check that window length is an integer and greater than zero.""" checks.check_is_not_none(window_length=window_length) diff --git a/tests/gaze/gaze_transform_test.py b/tests/gaze/gaze_transform_test.py index f92ee5c5b..8ec8b039b 100644 --- a/tests/gaze/gaze_transform_test.py +++ b/tests/gaze/gaze_transform_test.py @@ -439,6 +439,50 @@ def fixture_experiment(): ), id='pos2vel_preceding_trialize_single_column_str', ), + pytest.param( + { + 'data': pl.from_dict( + { + 'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0], + }, + ), + 'position_columns': ['x_dva', 'y_dva'], + }, + 'smooth', {'method': 'moving_average', 'window_length': 3}, + pm.GazeDataFrame( + data=pl.from_dict( + { + 'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0], + }, + ), + position_columns=['x_dva', 'y_dva'], + ), + id='smooth', + ), + pytest.param( + { + 'data': pl.from_dict( + { + 'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0], + }, + ), + 'position_columns': ['x_dva', 'y_dva'], + }, + pm.gaze.transforms.smooth, {'method': 'moving_average', 'window_length': 3}, + pm.GazeDataFrame( + data=pl.from_dict( + { + 'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0], + }, + ), + position_columns=['x_dva', 'y_dva'], + ), + id='smooth_method_pass', + ), ], ) @@ -750,3 +794,75 @@ def test_gaze_dataframe_pos2vel_exceptions(init_kwargs, exception, expected_msg) msg, = excinfo.value.args assert msg == expected_msg + + +@pytest.mark.parametrize( + ('gaze_init_kwargs', 'kwargs', 'expected'), + [ + pytest.param( + { + 'data': pl.from_dict( + { + 'x_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + 'y_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + 'x_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + 'y_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + }, + ), + 'pixel_columns': ['x_pix', 'y_pix'], + 'position_columns': ['x_dva', 'y_dva'], + }, + {'method': 'moving_average', 'column': 'pixel', 'window_length': 3}, + pm.GazeDataFrame( + data=pl.from_dict( + { + + 'x_pix': [1 / 3, 1 / 3, 1 / 3, 1 / 3, 1 / 3, 1 / 3], + 'y_pix': [1 / 3, 1 / 3, 1 / 3, 1 / 3, 1 / 3, 1 / 3], + 'x_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + 'y_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + }, + ), + pixel_columns=['x_pix', 'y_pix'], + position_columns=['x_dva', 'y_dva'], + ), + id='pixel', + ), + pytest.param( + { + 'data': pl.from_dict( + { + 'x_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + 'y_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + 'x_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + 'y_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + }, + ), + 'pixel_columns': ['x_pix', 'y_pix'], + 'position_columns': ['x_dva', 'y_dva'], + }, + {'method': 'moving_average', 'column': 'position', 'window_length': 3}, + pm.GazeDataFrame( + data=pl.from_dict( + { + + 'x_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + 'y_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + 'x_dva': [1 / 3, 1 / 3, 1 / 3, 1 / 3, 1 / 3, 1 / 3], + 'y_dva': [1 / 3, 1 / 3, 1 / 3, 1 / 3, 1 / 3, 1 / 3], + }, + ), + pixel_columns=['x_pix', 'y_pix'], + position_columns=['x_dva', 'y_dva'], + ), + id='position', + ), + ], +) +def test_gaze_dataframe_smooth_expected_column( + gaze_init_kwargs, kwargs, expected, +): + gaze = pm.GazeDataFrame(**gaze_init_kwargs) + gaze.smooth(**kwargs) + + assert_frame_equal(gaze.frame, expected.frame) diff --git a/tests/gaze/transforms/smooth_test.py b/tests/gaze/transforms/smooth_test.py new file mode 100644 index 000000000..4d1d34674 --- /dev/null +++ b/tests/gaze/transforms/smooth_test.py @@ -0,0 +1,413 @@ +# Copyright (c) 2022-2023 The pymovements Project Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Test pymovements.gaze.transforms.smooth.""" +import polars as pl +import pytest +from polars.testing import assert_frame_equal + +import pymovements as pm + + +@pytest.mark.parametrize( + 'kwargs, series, expected_df', + [ + # Method: moving_average + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 1, + 'padding': 'nearest', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + id='moving_average_window_length_1_returns_same_series', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': None, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [ + [None, None], [1 / 2, 1 / 2], + [1 / 2, 1 / 2], + ], pl.List(pl.Float64), + ), + id='moving_average_window_length_2_no_padding', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 0.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1 / 2, 1 / 2], [1 / 2, 1 / 2]], pl.List(pl.Float64)), + id='moving_average_window_length_2_constant_padding_0', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 1.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[1 / 2, 1 / 2], [1 / 2, 1 / 2], [1 / 2, 1 / 2]], pl.List(pl.Float64), + ), + id='moving_average_window_length_2_constant_padding_1', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 'nearest', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1 / 2, 1 / 2], [1 / 2, 1 / 2]], pl.List(pl.Float64)), + id='moving_average_window_length_2_nearest_padding', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 'mirror', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [ + [1 / 2, 1 / 2], [1 / 2, 1 / 2], + [1 / 2, 1 / 2], + ], pl.List(pl.Float64), + ), + id='moving_average_window_length_2_mirror_padding', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': None, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [ + [None, None], [1 / 3, 1 / 3], + [None, None], + ], pl.List(pl.Float64), + ), + id='moving_average_window_length_3_no_padding', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 0.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [ + [1 / 3, 1 / 3], [1 / 3, 1 / 3], + [1 / 3, 1 / 3], + ], pl.List(pl.Float64), + ), + id='moving_average_window_length_3_constant_padding_0', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 1.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[2 / 3, 2 / 3], [1 / 3, 1 / 3], [2 / 3, 2 / 3]], pl.List(pl.Float64), + ), + id='moving_average_window_length_3_constant_padding_1', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 'nearest', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [ + [1 / 3, 1 / 3], [1 / 3, 1 / 3], + [1 / 3, 1 / 3], + ], pl.List(pl.Float64), + ), + id='moving_average_window_length_3_nearest_padding', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 'mirror', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [ + [2 / 3, 2 / 3], [1 / 3, 1 / 3], + [2 / 3, 2 / 3], + ], pl.List(pl.Float64), + ), + id='moving_average_window_length_3_mirror_padding', + ), + # Method: exponential_moving_average + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 1, + 'padding': 'nearest', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_1_returns_same_series', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': None, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [ + [None, None], [2 / 3, 2 / 3], + [2 / 9, 2 / 9], + ], pl.List(pl.Float64), + ), + id='exponential_moving_average_window_length_2_no_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 0.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [2 / 3, 2 / 3], [2 / 9, 2 / 9]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_2_constant_padding_0', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 1.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[1 / 3, 1 / 3], [7 / 9, 7 / 9], [7 / 27, 7 / 27]], pl.List(pl.Float64), + ), + id='exponential_moving_average_window_length_2_constant_padding_1', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 'nearest', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [2 / 3, 2 / 3], [2 / 9, 2 / 9]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_2_nearest_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 'mirror', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[1 / 3, 1 / 3], [7 / 9, 7 / 9], [7 / 27, 7 / 27]], pl.List(pl.Float64), + ), + id='exponential_moving_average_window_length_2_mirror_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': None, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[None, None], [None, None], [0.25, 0.25]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_3_no_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 0.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [0.5, 0.5], [0.25, 0.25]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_3_constant_padding_0', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 1.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0.5, 0.5], [0.75, 0.75], [0.375, 0.375]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_3_constant_padding_1', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 'nearest', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [0.5, 0.5], [0.25, 0.25]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_3_nearest_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 'mirror', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[0.25, 0.25], [0.625, 0.625], [0.3125, 0.3125]], pl.List(pl.Float64), + ), + id='exponential_moving_average_window_length_3_mirror_padding', + ), + # Method: savitzky_golay + pytest.param( + { + 'method': 'savitzky_golay', + 'n_components': 2, + 'window_length': 2, + 'degree': 1, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0.5, 0.5], [0.5, 0.5], [0., 0.]], pl.List(pl.Float64)), + id='savitzky_golay_window_length_2_degree_1_returns_mean_of_window', + ), + pytest.param( + { + 'method': 'savitzky_golay', + 'n_components': 2, + 'window_length': 3, + 'degree': 1, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [ + [1 / 3, 1 / 3], [1 / 3, 1 / 3], + [1 / 3, 1 / 3], + ], pl.List(pl.Float64), + ), + id='savitzky_golay_window_length_3_degree_1_returns_mean_of_window', + ), + pytest.param( + { + 'method': 'savitzky_golay', + 'n_components': 2, + 'window_length': 3, + 'degree': 2, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + id='savitzky_golay_window_length_3_degree_2_returns', + ), + ], +) +def test_smooth_returns(kwargs, series, expected_df): + """Test if smooth returns the expected dataframe.""" + df = series.to_frame() + + result_df = df.select( + pm.gaze.transforms.smooth(**kwargs), + ) + + assert_frame_equal(result_df, expected_df.to_frame()) + + +@pytest.mark.parametrize( + 'kwargs, exception, msg_substrings', + [ + pytest.param( + { + 'method': 'invalid_method', + 'n_components': 2, + 'window_length': 3, + }, + ValueError, + "Unkown method 'invalid_method'. Supported methods are: ", + id='invalid_method_raises_value_error', + ), + pytest.param( + { + 'method': 'savitzky_golay', + 'n_components': 2, + 'window_length': 3, + 'degree': None, + }, + TypeError, + "'degree' must not be none for method 'savitzky_golay'", + id='savitzky_golay_degree_none_raises_type_error', + ), + ], +) +def test_smooth_init_raises_error(kwargs, exception, msg_substrings): + """Test if smooth init raises the expected error.""" + with pytest.raises(exception) as excinfo: + pm.gaze.transforms.smooth(**kwargs) + + msg, = excinfo.value.args + for msg_substring in msg_substrings: + assert msg_substring.lower() in msg.lower() + + +def test_identity_returns_same_series(): + """Test if identity returns the same series.""" + series = pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)) + series_identity = pm.gaze.transforms._identity(series) + + assert_frame_equal(series.to_frame(), series_identity.to_frame())