diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 66545eacd..7713bf914 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -67,7 +67,7 @@ repos: rev: v1.5.1 hooks: - id: mypy - additional_dependencies: [types-all] + additional_dependencies: [types-all, pandas-stubs, types-tqdm] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index d8e38dc23..ad9c041ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,8 @@ test = [ "pybtex", "pytest>=6.0.0", "pytest-cov>=4.0.0", - "types-tqdm" + "types-tqdm", + "typing_extensions" ] [project.urls] @@ -110,7 +111,6 @@ check_untyped_defs = true disallow_incomplete_defs = true disallow_untyped_defs = true warn_redundant_casts = true -warn_unused_ignores = true [[tool.mypy.overrides]] module = "scipy.*" diff --git a/src/pymovements/plotting/main_sequence_plot.py b/src/pymovements/plotting/main_sequence_plot.py index 551e73152..7edc7238b 100644 --- a/src/pymovements/plotting/main_sequence_plot.py +++ b/src/pymovements/plotting/main_sequence_plot.py @@ -110,10 +110,12 @@ def main_sequence_plot( alpha=alpha, s=marker_size, marker=marker, - **kwargs, + # to handle after https://github.com/pydata/xarray/pull/8030 is merged + **kwargs, # type: ignore ) - plt.title(title) + if title: + plt.title(title) plt.xlabel('Amplitude [dva]') plt.ylabel('Peak Velocity [dva/s]') diff --git a/src/pymovements/plotting/traceplot.py b/src/pymovements/plotting/traceplot.py index 9b6b62988..6839f840c 100644 --- a/src/pymovements/plotting/traceplot.py +++ b/src/pymovements/plotting/traceplot.py @@ -21,12 +21,17 @@ from __future__ import annotations import sys +from typing import Dict +from typing import Literal +from typing import Sequence +from typing import Tuple -import matplotlib +import matplotlib.colors import matplotlib.pyplot as plt +import matplotlib.scale import numpy as np -from matplotlib import colors from matplotlib.collections import LineCollection +from typing_extensions import TypeAlias from pymovements.gaze.gaze_dataframe import GazeDataFrame @@ -37,44 +42,59 @@ if 'pytest' in sys.modules: # pragma: no cover matplotlib.use('Agg') -DEFAULT_SEGMENTDATA = { +LinearSegmentedColormapType: TypeAlias = Dict[ + Literal['red', 'green', 'blue', 'alpha'], Sequence[Tuple[float, ...]], +] + +DEFAULT_SEGMENTDATA: LinearSegmentedColormapType = { 'red': [ - [0.0, 0.0, 0.0], - [0.5, 1.0, 1.0], - [1.0, 1.0, 1.0], + (0.0, 0.0, 0.0), + (0.5, 1.0, 1.0), + (1.0, 1.0, 1.0), ], 'green': [ - [0.0, 0.0, 0.0], - [0.5, 1.0, 1.0], - [1.0, 0.0, 0.0], + (0.0, 0.0, 0.0), + (0.5, 1.0, 1.0), + (1.0, 0.0, 0.0), ], 'blue': [ - [0.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - [1.0, 0.0, 0.0], + (0.0, 0.0, 0.0), + (0.5, 0.0, 0.0), + (1.0, 0.0, 0.0), + ], + 'alpha': [ + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), ], } -DEFAULT_SEGMENTDATA_TWOSLOPE = { +DEFAULT_SEGMENTDATA_TWOSLOPE: LinearSegmentedColormapType = { 'red': [ - [0.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - [0.75, 1.0, 1.0], - [1.0, 1.0, 1.0], + (0.0, 0.0, 0.0), + (0.5, 0.0, 0.0), + (0.75, 1.0, 1.0), + (1.0, 1.0, 1.0), ], 'green': [ - [0.0, 0.0, 0.0], - [0.25, 1.0, 1.0], - [0.5, 0.0, 0.0], - [0.75, 1.0, 1.0], - [1.0, 0.0, 0.0], + (0.0, 0.0, 0.0), + (0.25, 1.0, 1.0), + (0.5, 0.0, 0.0), + (0.75, 1.0, 1.0), + (1.0, 0.0, 0.0), ], 'blue': [ - [0.0, 1.0, 1.0], - [0.25, 1.0, 1.0], - [0.5, 0.0, 0.0], - [1.0, 0.0, 0.0], + (0.0, 1.0, 1.0), + (0.25, 1.0, 1.0), + (0.5, 0.0, 0.0), + (1.0, 0.0, 0.0), + ], + 'alpha': [ + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), ], } @@ -83,9 +103,9 @@ def traceplot( gaze: GazeDataFrame, position_column: str = 'pixel', cval: np.ndarray | None = None, # pragma: no cover - cmap: colors.Colormap | None = None, - cmap_norm: colors.Normalize | str | None = None, - cmap_segmentdata: dict[str, list[list[float]]] | None = None, + cmap: matplotlib.colors.Colormap | None = None, + cmap_norm: matplotlib.colors.Normalize | str | None = None, + cmap_segmentdata: LinearSegmentedColormapType | None = None, cbar_label: str | None = None, show_cbar: bool = False, padding: float | None = None, @@ -148,7 +168,7 @@ def traceplot( show_cbar = False cval_max = np.nanmax(np.abs(cval)) - cval_min = np.nanmin(cval) + cval_min = np.nanmin(cval).astype(float) if cmap_norm is None: if cval_max and cval_min < 0: @@ -183,7 +203,10 @@ def traceplot( elif isinstance(cmap_norm, str): # pylint: disable=protected-access - if (scale_class := matplotlib.scale._scale_mapping.get(cmap_norm, None)) is None: + # to handle after https://github.com/pydata/xarray/pull/8030 is merged + if ( + scale_class := matplotlib.scale._scale_mapping.get(cmap_norm, None) # type: ignore + ) is None: raise ValueError(f'cmap_norm string {cmap_norm} is not supported') norm_class = matplotlib.colors.make_norm_from_scale(scale_class) @@ -218,7 +241,8 @@ def traceplot( # sm.set_array(cval) fig.colorbar(line, label=cbar_label, ax=ax) - ax.set_title(title) + if title: + ax.set_title(title) if savepath is not None: fig.savefig(savepath) diff --git a/tests/plotting/main_sequence_plot_test.py b/tests/plotting/main_sequence_plot_test.py index 2c752839c..f960ea3bc 100644 --- a/tests/plotting/main_sequence_plot_test.py +++ b/tests/plotting/main_sequence_plot_test.py @@ -169,6 +169,31 @@ def test_main_sequence_plot_not_show(input_df, show, monkeypatch): mock_function.assert_not_called() +@pytest.mark.parametrize( + ('input_df', 'title'), + [ + pytest.param( + EventDataFrame( + pl.DataFrame( + { + 'amplitude': np.arange(100), + 'peak_velocity': np.linspace(10, 50, num=100), + 'name': ['saccade' for _ in range(100)], + }, + ), + ), + 'foo', + id='do_not_show_plot', + ), + ], +) +def test_main_sequence_plot_set_title(input_df, title, monkeypatch): + mock_function = Mock() + monkeypatch.setattr(plt, 'title', mock_function) + main_sequence_plot(input_df, title=title) + plt.close() + + @pytest.mark.parametrize( ('input_df', 'expected_error', 'error_msg'), [ diff --git a/tests/plotting/traceplot_test.py b/tests/plotting/traceplot_test.py index 0d1b2a3fd..8ded6fef2 100644 --- a/tests/plotting/traceplot_test.py +++ b/tests/plotting/traceplot_test.py @@ -109,6 +109,10 @@ def gaze_fixture(): {'cval': np.arange(0, 200), 'show_cbar': False}, id='show_cbar_false', ), + pytest.param( + {'cval': np.arange(0, 200), 'title': 'foo'}, + id='set_title', + ), ], ) def test_traceplot_show(gaze, kwargs, monkeypatch): diff --git a/tox.ini b/tox.ini index 50571bb92..5e5925bf9 100644 --- a/tox.ini +++ b/tox.ini @@ -80,9 +80,9 @@ commands = changedir = {toxinidir} deps = mypy - .[test] pandas-stubs types-all + types-tqdm commands = mypy {toxinidir}/src