Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add gaze.transforms.smooth() #555

Merged
merged 26 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
d930231
First implementation of smooth method
jakobchwastek Sep 12, 2023
8180f11
* Docs for padding
jakobchwastek Sep 12, 2023
d97555b
First Test for smooth method
jakobchwastek Sep 12, 2023
6451ea4
Fix padding functionality
jakobchwastek Sep 14, 2023
816e64e
Fix exponential moving average
jakobchwastek Sep 14, 2023
19bb655
Add _identity return value type
jakobchwastek Sep 14, 2023
b408d09
Add some tests
jakobchwastek Sep 14, 2023
d5b7895
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2023
5a7f63e
Fix mypy typing error
jakobchwastek Sep 14, 2023
b345aae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2023
8c6eceb
* Fix No blank lines allowed after function docstring
jakobchwastek Sep 14, 2023
a92e61b
Merge remote-tracking branch 'origin/feature/add-transforms-smooth' i…
jakobchwastek Sep 14, 2023
f74a75f
* minor fix pad_width type
jakobchwastek Sep 21, 2023
288f087
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2023
8457c30
* add elif; try fix coverage exponential_moving_average
jakobchwastek Sep 21, 2023
bf0be39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2023
3613e69
* Fix coverage
jakobchwastek Sep 22, 2023
590be44
* Remove unnecessary "else" after "return"
jakobchwastek Sep 22, 2023
f3bab1e
Merge branch 'main' into feature/add-transforms-smooth
jakobchwastek Sep 22, 2023
d5a6ddc
* add smooth method to gaze dataframe + tests
jakobchwastek Sep 22, 2023
aeeaa35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2023
3d925db
* fix no whitespaces allowed surrounding docstring text
jakobchwastek Sep 22, 2023
a6c73bd
* fix no whitespaces allowed surrounding docstring text
jakobchwastek Sep 22, 2023
24111c4
Merge branch 'main' into feature/add-transforms-smooth
SiQube Sep 22, 2023
12bd2df
Complete docs
jakobchwastek Sep 25, 2023
67bbbde
Merge branch 'main' into feature/add-transforms-smooth
SiQube Sep 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,52 @@ def pos2vel(
"""
self.transform('pos2vel', method=method, **kwargs)

def smooth(
        self,
        method: str = 'savitzky_golay',
        window_length: int = 7,
        degree: int = 2,
        column: str = 'position',
        padding: str | float | int | None = 'nearest',
        **kwargs: int | float | str,
) -> None:
    """Apply the ``smooth`` transform to one column of this gaze dataframe.

    This is a convenience wrapper: all arguments are forwarded unchanged to
    ``self.transform('smooth', ...)``.

    Parameters
    ----------
    method:
        Name of the smoothing method. One of ``savitzky_golay``, ``moving_average`` or
        ``exponential_moving_average``. See :func:`~transforms.smooth()` for details.
    window_length:
        Window size of the filter. For ``moving_average`` it is the number of samples
        that are averaged, for ``savitzky_golay`` the number of samples used for the
        polynomial fit and for ``exponential_moving_average`` the span parameter.
    degree:
        Degree of the fitted polynomial. Only relevant for ``savitzky_golay``.
        `degree` must be less than `window_length`.
    column:
        Name of the input column that is smoothed.
    padding:
        Either ``None``, a scalar, or one of the strings ``mirror``, ``nearest`` or
        ``wrap``. Selects how the signal is extended at its borders before filtering.
        ``None`` disables the extension, a scalar pads with that constant value.
        See :func:`~transforms.smooth()` for details on the padding methods.
    **kwargs:
        Additional keyword arguments to be passed to the :func:`~transforms.smooth()`
        method.
    """
    # Collect the explicit arguments first, then let the extra keyword arguments follow,
    # so that the keyword order seen by ``transform`` matches a direct call.
    transform_kwargs = {
        'column': column,
        'method': method,
        'degree': degree,
        'window_length': window_length,
        'padding': padding,
    }
    transform_kwargs.update(kwargs)

    self.transform('smooth', **transform_kwargs)

def detect(
self,
method: Callable[..., pm.EventDataFrame] | str,
Expand Down
160 changes: 159 additions & 1 deletion src/pymovements/gaze/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

from pymovements.utils import checks


TransformMethod = TypeVar('TransformMethod', bound=Callable[..., pl.Expr])


Expand Down Expand Up @@ -558,6 +557,165 @@ def savitzky_golay(
).alias(output_column)


@register_transform
def smooth(
        *,
        method: str,
        window_length: int,
        n_components: int,
        degree: int | None = None,
        column: str = 'position',
        padding: str | float | int | None = 'nearest',
) -> pl.Expr:
    """Build a polars expression that smooths the data of a list column.

    Parameters
    ----------
    method:
        The method to use for smoothing. See Notes for more details.
    window_length:
        For ``moving_average`` this is the window size to calculate the mean of the
        subsequent samples. For ``savitzky_golay`` this is the window size to use for the
        polynomial fit. For ``exponential_moving_average`` this is the span parameter.
    n_components:
        Number of components in input column.
    degree:
        The degree of the polynomial to use. This has only an effect if using
        ``savitzky_golay`` as smoothing method. `degree` must be less than `window_length`.
    column:
        The input column name to which the smoothing is applied. The result is written
        back under the same name.
    padding:
        Must be either ``None``, a scalar or one of the strings ``mirror``, ``nearest`` or
        ``wrap``. This determines the type of extension to use for the padded signal to
        which the filter is applied.
        When passing ``None``, no extension padding is used.
        When passing a scalar value, data will be padded using the passed value.
        See the Notes for more details on the padding methods.

    Returns
    -------
    polars.Expr
        The respective polars expression.

    Raises
    ------
    TypeError
        If `method` is ``savitzky_golay`` and `degree` is ``None``.
    ValueError
        If `method` is not one of the supported methods.

    Notes
    -----
    The following methods are available for smoothing:

    * ``savitzky_golay``: Smooth data by applying a Savitzky-Golay filter.
      See :py:func:`~pymovements.gaze.transforms.savitzky_golay` for further details.
    * ``moving_average``: Smooth data by calculating the mean of the subsequent samples.
      Each smoothed sample is calculated by the mean of the samples in the window around
      the sample.
    * ``exponential_moving_average``: Smooth data by exponentially weighted moving average.

    Details on the `padding` options:

    * ``None``: No padding extension is used.
    * scalar value (int or float): The padding extension contains the specified scalar
      value.
    * ``mirror``: Repeats the values at the edges in reverse order. The value closest to
      the edge is not included.
    * ``nearest``: The padding extension contains the nearest input value.
    * ``wrap``: The padding extension contains the values from the other end of the array.

    Given the input is ``[1, 2, 3, 4, 5, 6, 7, 8]``, and
    `window_length` is 7, the following table shows the padded data for
    the various ``padding`` options:

    +-------------+-------------+----------------------------+-------------+
    | mode        | padding     | input                      | padding     |
    +=============+=============+============================+=============+
    | ``None``    | ``- - -``   | ``1 2 3 4 5 6 7 8``        | ``- - -``   |
    +-------------+-------------+----------------------------+-------------+
    | ``0``       | ``0 0 0``   | ``1 2 3 4 5 6 7 8``        | ``0 0 0``   |
    +-------------+-------------+----------------------------+-------------+
    | ``1``       | ``1 1 1``   | ``1 2 3 4 5 6 7 8``        | ``1 1 1``   |
    +-------------+-------------+----------------------------+-------------+
    | ``nearest`` | ``1 1 1``   | ``1 2 3 4 5 6 7 8``        | ``8 8 8``   |
    +-------------+-------------+----------------------------+-------------+
    | ``mirror``  | ``4 3 2``   | ``1 2 3 4 5 6 7 8``        | ``7 6 5``   |
    +-------------+-------------+----------------------------+-------------+
    | ``wrap``    | ``6 7 8``   | ``1 2 3 4 5 6 7 8``        | ``1 2 3``   |
    +-------------+-------------+----------------------------+-------------+

    """
    _check_window_length(window_length=window_length)
    _check_padding(padding=padding)

    # The Savitzky-Golay filter has its own transform which handles padding itself.
    if method == 'savitzky_golay':
        if degree is None:
            raise TypeError("'degree' must not be none for method 'savitzky_golay'")

        return savitzky_golay(
            window_length=window_length,
            degree=degree,
            sampling_rate=1,
            padding=padding,
            derivative=0,
            n_components=n_components,
            input_column=column,
            output_column=None,
        )

    if method not in {'moving_average', 'exponential_moving_average'}:
        supported_methods = ['moving_average', 'exponential_moving_average', 'savitzky_golay']

        raise ValueError(
            f"Unknown method '{method}'. Supported methods are: {supported_methods}",
        )

    # Translate the public padding option into the corresponding np.pad arguments.
    pad_options: dict[str, Any] = {'pad_width': 0}
    if isinstance(padding, (int, float)):
        pad_options['constant_values'] = padding
        pad_mode = 'constant'
    elif padding in ('nearest', 'mirror'):
        # np.pad calls 'nearest' 'edge' and 'mirror' 'reflect'
        pad_mode = 'edge' if padding == 'nearest' else 'reflect'
    else:
        pad_mode = padding

    # Without padding the values are passed through untouched and the pad width stays 0.
    pad_func = _identity
    if pad_mode is not None:
        pad_options['mode'] = pad_mode
        pad_options['pad_width'] = np.ceil(window_length / 2).astype(int)
        pad_func = partial(np.pad, **pad_options)

    pad_width = pad_options['pad_width']

    def _smooth_component(component: int) -> pl.Expr:
        """Return the expression that pads, filters and crops a single component."""
        padded = pl.col(column).list.get(component).map(pad_func).list.explode()

        if method == 'moving_average':
            filtered = padded.rolling_mean(window_size=window_length, center=True)
        else:
            filtered = padded.ewm_mean(
                span=window_length,
                adjust=False,
                min_periods=window_length,
            )

        # Shifting by the pad width and dropping twice the pad width from the front
        # removes the padded samples on both ends again.
        return filtered.shift(periods=pad_width).slice(pad_width * 2)

    return pl.concat_list(
        [_smooth_component(component) for component in range(n_components)],
    ).alias(column)


def _identity(x: Any) -> Any:
"""Identity function as placeholder for None as padding."""
return x


def _check_window_length(window_length: Any) -> None:
"""Check that window length is an integer and greater than zero."""
checks.check_is_not_none(window_length=window_length)
Expand Down
116 changes: 116 additions & 0 deletions tests/gaze/gaze_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,50 @@ def fixture_experiment():
),
id='pos2vel_preceding_trialize_single_column_str',
),
pytest.param(
{
'data': pl.from_dict(
{
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
},
),
'position_columns': ['x_dva', 'y_dva'],
},
'smooth', {'method': 'moving_average', 'window_length': 3},
pm.GazeDataFrame(
data=pl.from_dict(
{
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
},
),
position_columns=['x_dva', 'y_dva'],
),
id='smooth',
),
pytest.param(
{
'data': pl.from_dict(
{
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
},
),
'position_columns': ['x_dva', 'y_dva'],
},
pm.gaze.transforms.smooth, {'method': 'moving_average', 'window_length': 3},
pm.GazeDataFrame(
data=pl.from_dict(
{
'x_dva': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
'y_dva': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
},
),
position_columns=['x_dva', 'y_dva'],
),
id='smooth_method_pass',
),

],
)
Expand Down Expand Up @@ -750,3 +794,75 @@ def test_gaze_dataframe_pos2vel_exceptions(init_kwargs, exception, expected_msg)

msg, = excinfo.value.args
assert msg == expected_msg


@pytest.mark.parametrize(
    ('gaze_init_kwargs', 'kwargs', 'expected'),
    [
        pytest.param(
            {
                'data': pl.from_dict(
                    {
                        'x_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        'y_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        'x_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        'y_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                    },
                ),
                'pixel_columns': ['x_pix', 'y_pix'],
                'position_columns': ['x_dva', 'y_dva'],
            },
            {'method': 'moving_average', 'column': 'pixel', 'window_length': 3},
            pm.GazeDataFrame(
                data=pl.from_dict(
                    {
                        # only the pixel components are smoothed
                        'x_pix': [1 / 3] * 6,
                        'y_pix': [1 / 3] * 6,
                        'x_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        'y_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                    },
                ),
                pixel_columns=['x_pix', 'y_pix'],
                position_columns=['x_dva', 'y_dva'],
            ),
            id='pixel',
        ),
        pytest.param(
            {
                'data': pl.from_dict(
                    {
                        'x_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        'y_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        'x_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        'y_dva': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                    },
                ),
                'pixel_columns': ['x_pix', 'y_pix'],
                'position_columns': ['x_dva', 'y_dva'],
            },
            {'method': 'moving_average', 'column': 'position', 'window_length': 3},
            pm.GazeDataFrame(
                data=pl.from_dict(
                    {
                        # only the position components are smoothed
                        'x_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        'y_pix': [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                        'x_dva': [1 / 3] * 6,
                        'y_dva': [1 / 3] * 6,
                    },
                ),
                pixel_columns=['x_pix', 'y_pix'],
                position_columns=['x_dva', 'y_dva'],
            ),
            id='position',
        ),
    ],
)
def test_gaze_dataframe_smooth_expected_column(
        gaze_init_kwargs, kwargs, expected,
):
    """Check that ``smooth`` changes the requested column and leaves the other unchanged."""
    gaze_df = pm.GazeDataFrame(**gaze_init_kwargs)
    gaze_df.smooth(**kwargs)

    assert_frame_equal(gaze_df.frame, expected.frame)
Loading