diff --git a/src/pymovements/plotting/main_sequence_plot.py b/src/pymovements/plotting/main_sequence_plot.py index 551e73152..f0f05f09e 100644 --- a/src/pymovements/plotting/main_sequence_plot.py +++ b/src/pymovements/plotting/main_sequence_plot.py @@ -110,10 +110,12 @@ def main_sequence_plot( alpha=alpha, s=marker_size, marker=marker, - **kwargs, + # XXX to handle after https://github.com/pydata/xarray/pull/8030 is merged + **kwargs, # type: ignore ) - plt.title(title) + if title: + plt.title(title) plt.xlabel('Amplitude [dva]') plt.ylabel('Peak Velocity [dva/s]') diff --git a/src/pymovements/plotting/traceplot.py b/src/pymovements/plotting/traceplot.py index 9b6b62988..84156bcbb 100644 --- a/src/pymovements/plotting/traceplot.py +++ b/src/pymovements/plotting/traceplot.py @@ -21,11 +21,14 @@ from __future__ import annotations import sys +from typing import Literal +from typing import Sequence +from typing import TypeAlias -import matplotlib +import matplotlib.colors import matplotlib.pyplot as plt +import matplotlib.scale import numpy as np -from matplotlib import colors from matplotlib.collections import LineCollection from pymovements.gaze.gaze_dataframe import GazeDataFrame @@ -37,44 +40,59 @@ if 'pytest' in sys.modules: # pragma: no cover matplotlib.use('Agg') -DEFAULT_SEGMENTDATA = { +LinearSegmentedColormapType: TypeAlias = dict[ + Literal['red', 'green', 'blue', 'alpha'], Sequence[tuple[float, ...]], +] + +DEFAULT_SEGMENTDATA: LinearSegmentedColormapType = { 'red': [ - [0.0, 0.0, 0.0], - [0.5, 1.0, 1.0], - [1.0, 1.0, 1.0], + (0.0, 0.0, 0.0), + (0.5, 1.0, 1.0), + (1.0, 1.0, 1.0), ], 'green': [ - [0.0, 0.0, 0.0], - [0.5, 1.0, 1.0], - [1.0, 0.0, 0.0], + (0.0, 0.0, 0.0), + (0.5, 1.0, 1.0), + (1.0, 0.0, 0.0), ], 'blue': [ - [0.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - [1.0, 0.0, 0.0], + (0.0, 0.0, 0.0), + (0.5, 0.0, 0.0), + (1.0, 0.0, 0.0), + ], + 'alpha': [ + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), ], } -DEFAULT_SEGMENTDATA_TWOSLOPE = { +DEFAULT_SEGMENTDATA_TWOSLOPE: LinearSegmentedColormapType = { 'red': [ - [0.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - [0.75, 1.0, 1.0], - [1.0, 1.0, 1.0], + (0.0, 0.0, 0.0), + (0.5, 0.0, 0.0), + (0.75, 1.0, 1.0), + (1.0, 1.0, 1.0), ], 'green': [ - [0.0, 0.0, 0.0], - [0.25, 1.0, 1.0], - [0.5, 0.0, 0.0], - [0.75, 1.0, 1.0], - [1.0, 0.0, 0.0], + (0.0, 0.0, 0.0), + (0.25, 1.0, 1.0), + (0.5, 0.0, 0.0), + (0.75, 1.0, 1.0), + (1.0, 0.0, 0.0), ], 'blue': [ - [0.0, 1.0, 1.0], - [0.25, 1.0, 1.0], - [0.5, 0.0, 0.0], - [1.0, 0.0, 0.0], + (0.0, 1.0, 1.0), + (0.25, 1.0, 1.0), + (0.5, 0.0, 0.0), + (1.0, 0.0, 0.0), + ], + 'alpha': [ + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), ], } @@ -83,9 +101,9 @@ def traceplot( gaze: GazeDataFrame, position_column: str = 'pixel', cval: np.ndarray | None = None, # pragma: no cover - cmap: colors.Colormap | None = None, - cmap_norm: colors.Normalize | str | None = None, - cmap_segmentdata: dict[str, list[list[float]]] | None = None, + cmap: matplotlib.colors.Colormap | None = None, + cmap_norm: matplotlib.colors.Normalize | str | None = None, + cmap_segmentdata: LinearSegmentedColormapType | None = None, cbar_label: str | None = None, show_cbar: bool = False, padding: float | None = None, @@ -148,7 +166,7 @@ def traceplot( show_cbar = False cval_max = np.nanmax(np.abs(cval)) - cval_min = np.nanmin(cval) + cval_min = np.nanmin(cval).astype(float) if cmap_norm is None: if cval_max and cval_min < 0: @@ -218,7 +236,8 @@ def traceplot( # sm.set_array(cval) fig.colorbar(line, label=cbar_label, ax=ax) - ax.set_title(title) + if title: + ax.set_title(title) if savepath is not None: fig.savefig(savepath) diff --git a/tests/plotting/main_sequence_plot_test.py b/tests/plotting/main_sequence_plot_test.py index 2c752839c..f960ea3bc 100644 --- a/tests/plotting/main_sequence_plot_test.py +++ b/tests/plotting/main_sequence_plot_test.py @@ -169,6 +169,31 @@ def test_main_sequence_plot_not_show(input_df, show, monkeypatch): mock_function.assert_not_called() +@pytest.mark.parametrize( + ('input_df', 'title'), + [ + pytest.param( + EventDataFrame( + pl.DataFrame( + { + 'amplitude': np.arange(100), + 'peak_velocity': np.linspace(10, 50, num=100), + 'name': ['saccade' for _ in range(100)], + }, + ), + ), + 'foo', + id='do_not_show_plot', + ), + ], +) +def test_main_sequence_plot_set_title(input_df, title, monkeypatch): + mock_function = Mock() + monkeypatch.setattr(plt, 'title', mock_function) + main_sequence_plot(input_df, title=title) + plt.close() + + @pytest.mark.parametrize( ('input_df', 'expected_error', 'error_msg'), [ diff --git a/tests/plotting/traceplot_test.py b/tests/plotting/traceplot_test.py index 0d1b2a3fd..8ded6fef2 100644 --- a/tests/plotting/traceplot_test.py +++ b/tests/plotting/traceplot_test.py @@ -109,6 +109,10 @@ def gaze_fixture(): {'cval': np.arange(0, 200), 'show_cbar': False}, id='show_cbar_false', ), + pytest.param( + {'cval': np.arange(0, 200), 'title': 'foo'}, + id='set_title', + ), ], ) def test_traceplot_show(gaze, kwargs, monkeypatch):