diff --git a/src/pymovements/gaze/transforms.py b/src/pymovements/gaze/transforms.py index 69daa1f12..e395e2a80 100644 --- a/src/pymovements/gaze/transforms.py +++ b/src/pymovements/gaze/transforms.py @@ -568,7 +568,7 @@ def smooth( Number of components in input column. degree: The degree of the polynomial to use. This has only an effect if using ``savitzky_golay`` as - smoothing method. + smoothing method. `degree` must be less than `window_length`. padding: Must be either ``None``, a scalar or one of the strings ``mirror``, ``nearest`` or ``wrap``. This determines the type of extension to use for the padded signal to @@ -626,7 +626,7 @@ def smooth( _check_padding(padding=padding) if method in {'moving_average', 'exponential_moving_average'}: - pad_kwargs: dict[str, Any] = {'pad_width': 0.0} + pad_kwargs: dict[str, Any] = {'pad_width': 0} pad_func = _identity if isinstance(padding, (int, float)): @@ -641,7 +641,7 @@ def smooth( if padding is not None: pad_kwargs['mode'] = padding - pad_kwargs['pad_width'] = window_length // 2 + pad_kwargs['pad_width'] = np.ceil(window_length / 2).astype(int) pad_func = partial( np.pad, diff --git a/tests/gaze/transforms/smooth_test.py b/tests/gaze/transforms/smooth_test.py index 11d1aeeb4..fe1fcd6ef 100644 --- a/tests/gaze/transforms/smooth_test.py +++ b/tests/gaze/transforms/smooth_test.py @@ -28,16 +28,75 @@ @pytest.mark.parametrize( 'kwargs, series, expected_df', [ + # Method: moving_average pytest.param( { 'method': 'moving_average', 'n_components': 2, 'window_length': 1, + 'padding': 'nearest', }, - pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)), - pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), id='moving_average_window_length_1_returns_same_series', ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': None, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[None, None], [1/2, 1/2], [1/2, 1/2]], pl.List(pl.Float64)), + id='moving_average_window_length_2_no_padding', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 0.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1/2, 1/2], [1/2, 1/2]], pl.List(pl.Float64)), + id='moving_average_window_length_2_constant_padding_0', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 1.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[1 / 2, 1 / 2], [1 / 2, 1 / 2], [1 / 2, 1 / 2]], pl.List(pl.Float64) + ), + id='moving_average_window_length_2_constant_padding_1', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 'nearest', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1/2, 1/2], [1/2, 1/2]], pl.List(pl.Float64)), + id='moving_average_window_length_2_nearest_padding', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 'mirror', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[1 / 2, 1/2], [1/2, 1/2], [1/2, 1/2]], pl.List(pl.Float64)), + id='moving_average_window_length_2_mirror_padding', + ), pytest.param( { 'method': 'moving_average', @@ -45,8 +104,8 @@ 'window_length': 3, 'padding': None, }, - pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)), - pl.Series('position', [[None, None], [2., 2.], [None, None]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[None, None], [1/3, 1/3], [None, None]], pl.List(pl.Float64)), id='moving_average_window_length_3_no_padding', ), pytest.param( @@ -56,9 +115,22 @@ 'window_length': 3, 'padding': 0.0, }, - pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)), - pl.Series('position', [[1., 1.], [2., 2.], [5 / 3, 5 / 3]], pl.List(pl.Float64)), - id='moving_average_window_length_3_constant_padding', + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[1/3, 1/3], [1/3, 1/3], [1/3, 1/3]], pl.List(pl.Float64)), + id='moving_average_window_length_3_constant_padding_0', + ), + pytest.param( + { + 'method': 'moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 1.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[2 / 3, 2 / 3], [1 / 3, 1 / 3], [2 / 3, 2 / 3]], pl.List(pl.Float64) + ), + id='moving_average_window_length_3_constant_padding_1', ), pytest.param( { @@ -67,8 +139,8 @@ 'window_length': 3, 'padding': 'nearest', }, - pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)), - pl.Series('position', [[4 / 3, 4 / 3], [2., 2.], [8 / 3, 8 / 3]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[1/3, 1/3], [1/3, 1/3], [1/3, 1/3]], pl.List(pl.Float64)), id='moving_average_window_length_3_nearest_padding', ), pytest.param( @@ -78,20 +150,139 @@ 'window_length': 3, 'padding': 'mirror', }, - pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)), - pl.Series('position', [[5 / 3, 5 / 3], [2., 2.], [7 / 3, 7 / 3]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[2/3, 2/3], [1/3, 1/3], [2/3, 2/3]], pl.List(pl.Float64)), id='moving_average_window_length_3_mirror_padding', ), + # Method: exponential_moving_average pytest.param( { 'method': 'exponential_moving_average', 'n_components': 2, 'window_length': 1, + 'padding': 'nearest', }, - pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)), - pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), id='exponential_moving_average_window_length_1_returns_same_series', ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': None, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[None, None], [2/3, 2/3], [2/9, 2/9]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_2_no_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 0.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [2/3, 2/3], [2/9, 2/9]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_2_constant_padding_0', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 1.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[1 / 3, 1 / 3], [7 / 9, 7 / 9], [7 / 27, 7 / 27]], pl.List(pl.Float64) + ), + id='exponential_moving_average_window_length_2_constant_padding_1', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 'nearest', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [2/3, 2/3], [2/9, 2/9]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_2_nearest_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 2, + 'padding': 'mirror', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[1 / 3, 1 / 3], [7 / 9, 7 / 9], [7 / 27, 7 / 27]], pl.List(pl.Float64) + ), + id='exponential_moving_average_window_length_2_mirror_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': None, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[None, None], [None, None], [0.25, 0.25]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_3_no_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 0.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [0.5, 0.5], [0.25, 0.25]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_3_constant_padding_0', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 1.0, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0.5, 0.5], [0.75, 0.75], [0.375, 0.375]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_3_constant_padding_1', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 'nearest', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0., 0.], [0.5, 0.5], [0.25, 0.25]], pl.List(pl.Float64)), + id='exponential_moving_average_window_length_3_nearest_padding', + ), + pytest.param( + { + 'method': 'exponential_moving_average', + 'n_components': 2, + 'window_length': 3, + 'padding': 'mirror', + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series( + 'position', [[0.25, 0.25], [0.625, 0.625], [0.3125, 0.3125]], pl.List(pl.Float64) + ), + id='exponential_moving_average_window_length_3_mirror_padding', + ), + # Method: savitzky_golay pytest.param( { 'method': 'savitzky_golay', @@ -99,9 +290,31 @@ 'window_length': 2, 'degree': 1, }, - pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)), - pl.Series('position', [[1.5, 1.5], [2.5, 2.5], [3., 3.]], pl.List(pl.Float64)), - id='savitzky_golay_window_length_2_degree_1_returns', + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0.5, 0.5], [0.5, 0.5], [0., 0.]], pl.List(pl.Float64)), + id='savitzky_golay_window_length_2_degree_1_returns_mean_of_window', + ), + pytest.param( + { + 'method': 'savitzky_golay', + 'n_components': 2, + 'window_length': 3, + 'degree': 1, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[1/3, 1/3], [1/3, 1/3], [1/3, 1/3]], pl.List(pl.Float64)), + id='savitzky_golay_window_length_3_degree_1_returns_mean_of_window', + ), + pytest.param( + { + 'method': 'savitzky_golay', + 'n_components': 2, + 'window_length': 3, + 'degree': 2, + }, + pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + pl.Series('position', [[0.,0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)), + id='savitzky_golay_window_length_3_degree_2_returns', ), ], )