Skip to content

Commit

Permalink
* minor fix pad_width type
Browse files Browse the repository at this point in the history
* use rounded up half of window_length as padding length
* add tests for all smoothing methods
  • Loading branch information
jakobchwastek committed Sep 21, 2023
1 parent a92e61b commit f74a75f
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 19 deletions.
6 changes: 3 additions & 3 deletions src/pymovements/gaze/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def smooth(
Number of components in input column.
degree:
The degree of the polynomial to use. This has only an effect if using ``savitzky_golay`` as
smoothing method.
smoothing method. `degree` must be less than `window_length`.
padding:
Must be either ``None``, a scalar or one of the strings ``mirror``, ``nearest`` or ``wrap``.
This determines the type of extension to use for the padded signal to
Expand Down Expand Up @@ -626,7 +626,7 @@ def smooth(
_check_padding(padding=padding)

if method in {'moving_average', 'exponential_moving_average'}:
pad_kwargs: dict[str, Any] = {'pad_width': 0.0}
pad_kwargs: dict[str, Any] = {'pad_width': 0}
pad_func = _identity

if isinstance(padding, (int, float)):
Expand All @@ -641,7 +641,7 @@ def smooth(

if padding is not None:
pad_kwargs['mode'] = padding
pad_kwargs['pad_width'] = window_length // 2
pad_kwargs['pad_width'] = np.ceil(window_length / 2).astype(int)

pad_func = partial(
np.pad,
Expand Down
245 changes: 229 additions & 16 deletions tests/gaze/transforms/smooth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,84 @@
@pytest.mark.parametrize(
'kwargs, series, expected_df',
[
# Method: moving_average
pytest.param(
{
'method': 'moving_average',
'n_components': 2,
'window_length': 1,
'padding': 'nearest',
},
pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)),
pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
id='moving_average_window_length_1_returns_same_series',
),
pytest.param(
{
'method': 'moving_average',
'n_components': 2,
'window_length': 2,
'padding': None,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[None, None], [1/2, 1/2], [1/2, 1/2]], pl.List(pl.Float64)),
id='moving_average_window_length_2_no_padding',
),
pytest.param(
{
'method': 'moving_average',
'n_components': 2,
'window_length': 2,
'padding': 0.0,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [1/2, 1/2], [1/2, 1/2]], pl.List(pl.Float64)),
id='moving_average_window_length_2_constant_padding_0',
),
pytest.param(
{
'method': 'moving_average',
'n_components': 2,
'window_length': 2,
'padding': 1.0,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series(
'position', [[1 / 2, 1 / 2], [1 / 2, 1 / 2], [1 / 2, 1 / 2]], pl.List(pl.Float64)
),
id='moving_average_window_length_2_constant_padding_1',
),
pytest.param(
{
'method': 'moving_average',
'n_components': 2,
'window_length': 2,
'padding': 'nearest',
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [1/2, 1/2], [1/2, 1/2]], pl.List(pl.Float64)),
id='moving_average_window_length_2_nearest_padding',
),
pytest.param(
{
'method': 'moving_average',
'n_components': 2,
'window_length': 2,
'padding': 'mirror',
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[1 / 2, 1/2], [1/2, 1/2], [1/2, 1/2]], pl.List(pl.Float64)),
id='moving_average_window_length_2_mirror_padding',
),
pytest.param(
{
'method': 'moving_average',
'n_components': 2,
'window_length': 3,
'padding': None,
},
pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)),
pl.Series('position', [[None, None], [2., 2.], [None, None]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[None, None], [1/3, 1/3], [None, None]], pl.List(pl.Float64)),
id='moving_average_window_length_3_no_padding',
),
pytest.param(
Expand All @@ -56,9 +115,22 @@
'window_length': 3,
'padding': 0.0,
},
pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)),
pl.Series('position', [[1., 1.], [2., 2.], [5 / 3, 5 / 3]], pl.List(pl.Float64)),
id='moving_average_window_length_3_constant_padding',
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[1/3, 1/3], [1/3, 1/3], [1/3, 1/3]], pl.List(pl.Float64)),
id='moving_average_window_length_3_constant_padding_0',
),
pytest.param(
{
'method': 'moving_average',
'n_components': 2,
'window_length': 3,
'padding': 1.0,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series(
'position', [[2 / 3, 2 / 3], [1 / 3, 1 / 3], [2 / 3, 2 / 3]], pl.List(pl.Float64)
),
id='moving_average_window_length_3_constant_padding_1',
),
pytest.param(
{
Expand All @@ -67,8 +139,8 @@
'window_length': 3,
'padding': 'nearest',
},
pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)),
pl.Series('position', [[4 / 3, 4 / 3], [2., 2.], [8 / 3, 8 / 3]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[1/3, 1/3], [1/3, 1/3], [1/3, 1/3]], pl.List(pl.Float64)),
id='moving_average_window_length_3_nearest_padding',
),
pytest.param(
Expand All @@ -78,30 +150,171 @@
'window_length': 3,
'padding': 'mirror',
},
pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)),
pl.Series('position', [[5 / 3, 5 / 3], [2., 2.], [7 / 3, 7 / 3]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[2/3, 2/3], [1/3, 1/3], [2/3, 2/3]], pl.List(pl.Float64)),
id='moving_average_window_length_3_mirror_padding',
),
# Method: exponential_moving_average
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 1,
'padding': 'nearest',
},
pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)),
pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
id='exponential_moving_average_window_length_1_returns_same_series',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 2,
'padding': None,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[None, None], [2/3, 2/3], [2/9, 2/9]], pl.List(pl.Float64)),
id='exponential_moving_average_window_length_2_no_padding',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 2,
'padding': 0.0,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [2/3, 2/3], [2/9, 2/9]], pl.List(pl.Float64)),
id='exponential_moving_average_window_length_2_constant_padding_0',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 2,
'padding': 1.0,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series(
'position', [[1 / 3, 1 / 3], [7 / 9, 7 / 9], [7 / 27, 7 / 27]], pl.List(pl.Float64)
),
id='exponential_moving_average_window_length_2_constant_padding_1',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 2,
'padding': 'nearest',
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [2/3, 2/3], [2/9, 2/9]], pl.List(pl.Float64)),
id='exponential_moving_average_window_length_2_nearest_padding',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 2,
'padding': 'mirror',
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series(
'position', [[1 / 3, 1 / 3], [7 / 9, 7 / 9], [7 / 27, 7 / 27]], pl.List(pl.Float64)
),
id='exponential_moving_average_window_length_2_mirror_padding',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 3,
'padding': None,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[None, None], [None, None], [0.25, 0.25]], pl.List(pl.Float64)),
id='exponential_moving_average_window_length_3_no_padding',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 3,
'padding': 0.0,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [0.5, 0.5], [0.25, 0.25]], pl.List(pl.Float64)),
id='exponential_moving_average_window_length_3_constant_padding_0',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 3,
'padding': 1.0,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0.5, 0.5], [0.75, 0.75], [0.375, 0.375]], pl.List(pl.Float64)),
id='exponential_moving_average_window_length_3_constant_padding_1',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 3,
'padding': 'nearest',
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0., 0.], [0.5, 0.5], [0.25, 0.25]], pl.List(pl.Float64)),
id='exponential_moving_average_window_length_3_nearest_padding',
),
pytest.param(
{
'method': 'exponential_moving_average',
'n_components': 2,
'window_length': 3,
'padding': 'mirror',
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series(
'position', [[0.25, 0.25], [0.625, 0.625], [0.3125, 0.3125]], pl.List(pl.Float64)
),
id='exponential_moving_average_window_length_3_mirror_padding',
),
# Method: savitzky_golay
pytest.param(
{
'method': 'savitzky_golay',
'n_components': 2,
'window_length': 2,
'degree': 1,
},
pl.Series('position', [[1., 1.], [2., 2.], [3., 3.]], pl.List(pl.Float64)),
pl.Series('position', [[1.5, 1.5], [2.5, 2.5], [3., 3.]], pl.List(pl.Float64)),
id='savitzky_golay_window_length_2_degree_1_returns',
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0.5, 0.5], [0.5, 0.5], [0., 0.]], pl.List(pl.Float64)),
id='savitzky_golay_window_length_2_degree_1_returns_mean_of_window',
),
pytest.param(
{
'method': 'savitzky_golay',
'n_components': 2,
'window_length': 3,
'degree': 1,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[1/3, 1/3], [1/3, 1/3], [1/3, 1/3]], pl.List(pl.Float64)),
id='savitzky_golay_window_length_3_degree_1_returns_mean_of_window',
),
pytest.param(
{
'method': 'savitzky_golay',
'n_components': 2,
'window_length': 3,
'degree': 2,
},
pl.Series('position', [[0., 0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
pl.Series('position', [[0.,0.], [1., 1.], [0., 0.]], pl.List(pl.Float64)),
id='savitzky_golay_window_length_3_degree_2_returns',
),
],
)
Expand Down

0 comments on commit f74a75f

Please sign in to comment.