Skip to content

Commit

Permalink
feat: Add gaze.transforms.smooth() (#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobchwastek authored Sep 29, 2023
1 parent c8ef04d commit dc3a32d
Show file tree
Hide file tree
Showing 4 changed files with 734 additions and 1 deletion.
46 changes: 46 additions & 0 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,52 @@ def pos2vel(
"""
self.transform('pos2vel', method=method, **kwargs)

def smooth(
        self,
        method: str = 'savitzky_golay',
        window_length: int = 7,
        degree: int = 2,
        column: str = 'position',
        padding: str | float | int | None = 'nearest',
        **kwargs: int | float | str,
) -> None:
    """Smooth the samples of a single column.

    This is a convenience wrapper: it forwards all arguments to
    ``self.transform`` using the registered ``'smooth'`` transform.

    Parameters
    ----------
    method:
        Name of the smoothing method. One of ``savitzky_golay``, ``moving_average`` or
        ``exponential_moving_average``. See :func:`~transforms.smooth()` for details.
    window_length:
        Size of the smoothing window. For ``moving_average`` it is the number of samples
        that are averaged, for ``savitzky_golay`` it is the window used for the polynomial
        fit and for ``exponential_moving_average`` it is the span parameter.
    degree:
        Degree of the fitted polynomial. Only relevant for ``savitzky_golay``.
        `degree` must be less than `window_length`.
    column:
        Name of the input column that is smoothed.
    padding:
        Either ``None``, a scalar, or one of the strings ``mirror``, ``nearest`` or ``wrap``.
        Selects how the signal is extended at its borders before filtering.
        ``None`` disables the extension, a scalar pads with that constant value.
        See :func:`~transforms.smooth()` for details on the padding methods.
    **kwargs:
        Additional keyword arguments that are handed on to :func:`~transforms.smooth()`.
    """
    # Collected in the same order in which they are passed on to the transform.
    smooth_kwargs = {
        'column': column,
        'method': method,
        'degree': degree,
        'window_length': window_length,
        'padding': padding,
    }
    self.transform('smooth', **smooth_kwargs, **kwargs)

def detect(
self,
method: Callable[..., pm.EventDataFrame] | str,
Expand Down
160 changes: 159 additions & 1 deletion src/pymovements/gaze/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

from pymovements.utils import checks


TransformMethod = TypeVar('TransformMethod', bound=Callable[..., pl.Expr])


Expand Down Expand Up @@ -558,6 +557,165 @@ def savitzky_golay(
).alias(output_column)


@register_transform
def smooth(
        *,
        method: str,
        window_length: int,
        n_components: int,
        degree: int | None = None,
        column: str = 'position',
        padding: str | float | int | None = 'nearest',
) -> pl.Expr:
    """Smooth data in a column.

    Parameters
    ----------
    method:
        The method to use for smoothing. See Notes for more details.
    window_length
        For ``moving_average`` this is the window size to calculate the mean of the subsequent
        samples. For ``savitzky_golay`` this is the window size to use for the polynomial fit.
        For ``exponential_moving_average`` this is the span parameter.
    n_components:
        Number of components in input column.
    degree:
        The degree of the polynomial to use. This has only an effect if using ``savitzky_golay`` as
        smoothing method. `degree` must be less than `window_length`.
    column:
        The input column name to which the smoothing is applied.
    padding:
        Must be either ``None``, a scalar or one of the strings ``mirror``, ``nearest`` or ``wrap``.
        This determines the type of extension to use for the padded signal to
        which the filter is applied.
        When passing ``None``, no extension padding is used.
        When passing a scalar value, data will be padded using the passed value.
        See the Notes for more details on the padding methods.

    Returns
    -------
    polars.Expr
        The respective polars expression.

    Raises
    ------
    TypeError
        If `method` is ``savitzky_golay`` and `degree` is ``None``.
    ValueError
        If `method` is not one of the supported methods.

    Notes
    -----
    The following methods are available for smoothing:

    * ``savitzky_golay``: Smooth data by applying a Savitzky-Golay filter.
      See :py:func:`~pymovements.gaze.transforms.savitzky_golay` for further details.
    * ``moving_average``: Smooth data by calculating the mean of the subsequent samples.
      Each smoothed sample is calculated by the mean of the samples in the window around the sample.
    * ``exponential_moving_average``: Smooth data by exponentially weighted moving average.

    Details on the `padding` options:

    * ``None``: No padding extension is used.
    * scalar value (int or float): The padding extension contains the specified scalar value.
    * ``mirror``: Repeats the values at the edges in reverse order. The value closest to the edge is
      not included.
    * ``nearest``: The padding extension contains the nearest input value.
    * ``wrap``: The padding extension contains the values from the other end of the array.

    Given the input is ``[1, 2, 3, 4, 5, 6, 7, 8]``, and
    `window_length` is 7, the following table shows the padded data for
    the various ``padding`` options:

    +-------------+-------------+----------------------------+-------------+
    | mode        | padding     | input                      | padding     |
    +=============+=============+============================+=============+
    | ``None``    | ``-  -  -`` | ``1  2  3  4  5  6  7  8`` | ``-  -  -`` |
    +-------------+-------------+----------------------------+-------------+
    | ``0``       | ``0  0  0`` | ``1  2  3  4  5  6  7  8`` | ``0  0  0`` |
    +-------------+-------------+----------------------------+-------------+
    | ``1``       | ``1  1  1`` | ``1  2  3  4  5  6  7  8`` | ``1  1  1`` |
    +-------------+-------------+----------------------------+-------------+
    | ``nearest`` | ``1  1  1`` | ``1  2  3  4  5  6  7  8`` | ``8  8  8`` |
    +-------------+-------------+----------------------------+-------------+
    | ``mirror``  | ``4  3  2`` | ``1  2  3  4  5  6  7  8`` | ``7  6  5`` |
    +-------------+-------------+----------------------------+-------------+
    | ``wrap``    | ``6  7  8`` | ``1  2  3  4  5  6  7  8`` | ``1  2  3`` |
    +-------------+-------------+----------------------------+-------------+
    """
    _check_window_length(window_length=window_length)
    _check_padding(padding=padding)

    if method == 'savitzky_golay':
        if degree is None:
            raise TypeError("'degree' must not be none for method 'savitzky_golay'")

        # The padding argument is handed over untouched; only the rolling
        # filters below need the translation to np.pad mode names.
        return savitzky_golay(
            window_length=window_length,
            degree=degree,
            sampling_rate=1,
            padding=padding,
            derivative=0,
            n_components=n_components,
            input_column=column,
            output_column=None,
        )

    supported_methods = ['moving_average', 'exponential_moving_average', 'savitzky_golay']
    if method not in supported_methods:
        raise ValueError(
            f"Unknown method '{method}'. Supported methods are: {supported_methods}",
        )

    # Without padding the samples are passed through unchanged and nothing
    # has to be trimmed afterwards.
    pad_width = 0
    pad_func = _identity

    if padding is not None:
        pad_kwargs: dict[str, Any] = {}
        if isinstance(padding, (int, float)):
            pad_kwargs['mode'] = 'constant'
            pad_kwargs['constant_values'] = padding
        else:
            # np.pad calls 'nearest' -> 'edge' and 'mirror' -> 'reflect';
            # every other name (e.g. 'wrap') is passed on as is.
            np_pad_modes = {'nearest': 'edge', 'mirror': 'reflect'}
            pad_kwargs['mode'] = np_pad_modes.get(padding, padding)

        # ceil(window_length / 2) in integer arithmetic, so that a plain
        # Python int (not a NumPy scalar) reaches the polars expressions.
        pad_width = (window_length + 1) // 2
        pad_func = partial(np.pad, pad_width=pad_width, **pad_kwargs)

    def _smooth_component(component: int) -> pl.Expr:
        """Build the smoothing expression for a single list component."""
        padded = pl.col(column).list.get(component).map(pad_func).list.explode()

        if method == 'moving_average':
            filtered = padded.rolling_mean(window_size=window_length, center=True)
        else:
            filtered = padded.ewm_mean(
                span=window_length,
                adjust=False,
                min_periods=window_length,
            )

        # Shifting by the pad width and dropping twice the pad width from the
        # front removes the padding on both ends again.
        return filtered.shift(periods=pad_width).slice(pad_width * 2)

    return pl.concat_list(
        [_smooth_component(component) for component in range(n_components)],
    ).alias(column)


def _identity(x: Any) -> Any:
"""Identity function as placeholder for None as padding."""
return x


def _check_window_length(window_length: Any) -> None:
"""Check that window length is an integer and greater than zero."""
checks.check_is_not_none(window_length=window_length)
Expand Down
116 changes: 116 additions & 0 deletions tests/gaze/gaze_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,50 @@ def fixture_experiment():
),
id='pos2vel_preceding_trialize_single_column_str',
),
pytest.param(
{
'data': pl.from_dict(
{
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
},
),
'position_columns': ['x_dva', 'y_dva'],
},
'smooth', {'method': 'moving_average', 'window_length': 3},
pm.GazeDataFrame(
data=pl.from_dict(
{
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
},
),
position_columns=['x_dva', 'y_dva'],
),
id='smooth',
),
pytest.param(
{
'data': pl.from_dict(
{
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
},
),
'position_columns': ['x_dva', 'y_dva'],
},
pm.gaze.transforms.smooth, {'method': 'moving_average', 'window_length': 3},
pm.GazeDataFrame(
data=pl.from_dict(
{
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
},
),
position_columns=['x_dva', 'y_dva'],
),
id='smooth_method_pass',
),
],
)
Expand Down Expand Up @@ -750,3 +794,75 @@ def test_gaze_dataframe_pos2vel_exceptions(init_kwargs, exception, expected_msg)

msg, = excinfo.value.args
assert msg == expected_msg


def _smooth_column_case(column, smoothed_columns):
    """Build one test case in which only ``smoothed_columns`` get smoothed.

    All four input columns carry the same raw signal. With a moving average of
    window length 3 every sample of a smoothed column is expected to be 1/3,
    while the remaining columns are expected to stay untouched.
    """
    column_names = ('x_pix', 'y_pix', 'x_dva', 'y_dva')

    raw_data = {name: [0.0, 1.0, 0.0, 0.0, 1.0, 0.0] for name in column_names}
    expected_data = {
        name: [1 / 3] * 6 if name in smoothed_columns else [0.0, 1.0, 0.0, 0.0, 1.0, 0.0]
        for name in column_names
    }

    return pytest.param(
        {
            'data': pl.from_dict(raw_data),
            'pixel_columns': ['x_pix', 'y_pix'],
            'position_columns': ['x_dva', 'y_dva'],
        },
        {'method': 'moving_average', 'column': column, 'window_length': 3},
        pm.GazeDataFrame(
            data=pl.from_dict(expected_data),
            pixel_columns=['x_pix', 'y_pix'],
            position_columns=['x_dva', 'y_dva'],
        ),
        id=column,
    )


@pytest.mark.parametrize(
    ('gaze_init_kwargs', 'kwargs', 'expected'),
    [
        _smooth_column_case('pixel', smoothed_columns=('x_pix', 'y_pix')),
        _smooth_column_case('position', smoothed_columns=('x_dva', 'y_dva')),
    ],
)
def test_gaze_dataframe_smooth_expected_column(
        gaze_init_kwargs, kwargs, expected,
):
    """Check that smoothing changes the selected column and leaves the others alone."""
    gaze = pm.GazeDataFrame(**gaze_init_kwargs)
    gaze.smooth(**kwargs)

    assert_frame_equal(gaze.frame, expected.frame)
Loading

0 comments on commit dc3a32d

Please sign in to comment.