From d8a884e89ab9f5a011ba9daa7c85185ff90d9643 Mon Sep 17 00:00:00 2001 From: SiQube Date: Fri, 15 Sep 2023 11:28:10 -0400 Subject: [PATCH 1/3] fix mypy to upgrade matplotlib --- .../plotting/main_sequence_plot.py | 6 +- src/pymovements/plotting/traceplot.py | 81 ++++++++++++------- tests/plotting/main_sequence_plot_test.py | 25 ++++++ tests/plotting/traceplot_test.py | 4 + 4 files changed, 83 insertions(+), 33 deletions(-) diff --git a/src/pymovements/plotting/main_sequence_plot.py b/src/pymovements/plotting/main_sequence_plot.py index 551e73152..f0f05f09e 100644 --- a/src/pymovements/plotting/main_sequence_plot.py +++ b/src/pymovements/plotting/main_sequence_plot.py @@ -110,10 +110,12 @@ def main_sequence_plot( alpha=alpha, s=marker_size, marker=marker, - **kwargs, + # XXX to handle after https://github.com/pydata/xarray/pull/8030 is merged + **kwargs, # type: ignore ) - plt.title(title) + if title: + plt.title(title) plt.xlabel('Amplitude [dva]') plt.ylabel('Peak Velocity [dva/s]') diff --git a/src/pymovements/plotting/traceplot.py b/src/pymovements/plotting/traceplot.py index 9b6b62988..84156bcbb 100644 --- a/src/pymovements/plotting/traceplot.py +++ b/src/pymovements/plotting/traceplot.py @@ -21,11 +21,14 @@ from __future__ import annotations import sys +from typing import Literal +from typing import Sequence +from typing import TypeAlias -import matplotlib +import matplotlib.colors import matplotlib.pyplot as plt +import matplotlib.scale import numpy as np -from matplotlib import colors from matplotlib.collections import LineCollection from pymovements.gaze.gaze_dataframe import GazeDataFrame @@ -37,44 +40,59 @@ if 'pytest' in sys.modules: # pragma: no cover matplotlib.use('Agg') -DEFAULT_SEGMENTDATA = { +LinearSegmentedColormapType: TypeAlias = dict[ + Literal['red', 'green', 'blue', 'alpha'], Sequence[tuple[float, ...]], +] + +DEFAULT_SEGMENTDATA: LinearSegmentedColormapType = { 'red': [ - [0.0, 0.0, 0.0], - [0.5, 1.0, 1.0], - [1.0, 1.0, 1.0], + (0.0, 0.0, 0.0), + (0.5, 1.0, 1.0), + (1.0, 1.0, 1.0), ], 'green': [ - [0.0, 0.0, 0.0], - [0.5, 1.0, 1.0], - [1.0, 0.0, 0.0], + (0.0, 0.0, 0.0), + (0.5, 1.0, 1.0), + (1.0, 0.0, 0.0), ], 'blue': [ - [0.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - [1.0, 0.0, 0.0], + (0.0, 0.0, 0.0), + (0.5, 0.0, 0.0), + (1.0, 0.0, 0.0), + ], + 'alpha': [ + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), ], } -DEFAULT_SEGMENTDATA_TWOSLOPE = { +DEFAULT_SEGMENTDATA_TWOSLOPE: LinearSegmentedColormapType = { 'red': [ - [0.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - [0.75, 1.0, 1.0], - [1.0, 1.0, 1.0], + (0.0, 0.0, 0.0), + (0.5, 0.0, 0.0), + (0.75, 1.0, 1.0), + (1.0, 1.0, 1.0), ], 'green': [ - [0.0, 0.0, 0.0], - [0.25, 1.0, 1.0], - [0.5, 0.0, 0.0], - [0.75, 1.0, 1.0], - [1.0, 0.0, 0.0], + (0.0, 0.0, 0.0), + (0.25, 1.0, 1.0), + (0.5, 0.0, 0.0), + (0.75, 1.0, 1.0), + (1.0, 0.0, 0.0), ], 'blue': [ - [0.0, 1.0, 1.0], - [0.25, 1.0, 1.0], - [0.5, 0.0, 0.0], - [1.0, 0.0, 0.0], + (0.0, 1.0, 1.0), + (0.25, 1.0, 1.0), + (0.5, 0.0, 0.0), + (1.0, 0.0, 0.0), + ], + 'alpha': [ + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + (1.0, 1.0, 1.0), ], } @@ -83,9 +101,9 @@ def traceplot( gaze: GazeDataFrame, position_column: str = 'pixel', cval: np.ndarray | None = None, # pragma: no cover - cmap: colors.Colormap | None = None, - cmap_norm: colors.Normalize | str | None = None, - cmap_segmentdata: dict[str, list[list[float]]] | None = None, + cmap: matplotlib.colors.Colormap | None = None, + cmap_norm: matplotlib.colors.Normalize | str | None = None, + cmap_segmentdata: LinearSegmentedColormapType | None = None, cbar_label: str | None = None, show_cbar: bool = False, padding: float | None = None, @@ -148,7 +166,7 @@ def traceplot( show_cbar = False cval_max = np.nanmax(np.abs(cval)) - cval_min = np.nanmin(cval) + cval_min = np.nanmin(cval).astype(float) if cmap_norm is None: if cval_max and cval_min < 0: @@ -218,7 +236,8 @@ def traceplot( # sm.set_array(cval) fig.colorbar(line, label=cbar_label, ax=ax) - ax.set_title(title) + if title: + ax.set_title(title) if savepath is not None: fig.savefig(savepath) diff --git a/tests/plotting/main_sequence_plot_test.py b/tests/plotting/main_sequence_plot_test.py index 2c752839c..f960ea3bc 100644 --- a/tests/plotting/main_sequence_plot_test.py +++ b/tests/plotting/main_sequence_plot_test.py @@ -169,6 +169,31 @@ def test_main_sequence_plot_not_show(input_df, show, monkeypatch): mock_function.assert_not_called() +@pytest.mark.parametrize( + ('input_df', 'title'), + [ + pytest.param( + EventDataFrame( + pl.DataFrame( + { + 'amplitude': np.arange(100), + 'peak_velocity': np.linspace(10, 50, num=100), + 'name': ['saccade' for _ in range(100)], + }, + ), + ), + 'foo', + id='do_not_show_plot', + ), + ], +) +def test_main_sequence_plot_set_title(input_df, title, monkeypatch): + mock_function = Mock() + monkeypatch.setattr(plt, 'title', mock_function) + main_sequence_plot(input_df, title=title) + plt.close() + + @pytest.mark.parametrize( ('input_df', 'expected_error', 'error_msg'), [ diff --git a/tests/plotting/traceplot_test.py b/tests/plotting/traceplot_test.py index 0d1b2a3fd..8ded6fef2 100644 --- a/tests/plotting/traceplot_test.py +++ b/tests/plotting/traceplot_test.py @@ -109,6 +109,10 @@ def gaze_fixture(): {'cval': np.arange(0, 200), 'show_cbar': False}, id='show_cbar_false', ), + pytest.param( + {'cval': np.arange(0, 200), 'title': 'foo'}, + id='set_title', + ), ], ) def test_traceplot_show(gaze, kwargs, monkeypatch): From 68a15e78297acda520b65aa9bdc71d229efb203e Mon Sep 17 00:00:00 2001 From: SiQube Date: Fri, 15 Sep 2023 12:45:00 -0400 Subject: [PATCH 2/3] update tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 973f1e4e2..053cbad63 100644 --- a/tox.ini +++ b/tox.ini @@ -86,9 +86,9 @@ commands = pre-commit run --all-files --show-diff-on-failure changedir = {toxinidir} deps = mypy - .[test] pandas-stubs types-all + types-tqdm commands = mypy {toxinidir}/src From 90bc5571bc58c55abf4c3b3120f96c8cb27755a1 Mon Sep 17 00:00:00 2001 From: SiQube Date: Fri, 15 Sep 2023 13:09:07 -0400 Subject: [PATCH 3/3] always test 3.8 and 3.11, typealias not a thing yet --- .pre-commit-config.yaml | 2 +- pyproject.toml | 4 ++-- src/pymovements/plotting/main_sequence_plot.py | 2 +- src/pymovements/plotting/traceplot.py | 13 +++++++++---- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 90b65b910..12d595fc6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -67,7 +67,7 @@ repos: rev: v1.5.1 hooks: - id: mypy - additional_dependencies: [types-all] + additional_dependencies: [types-all, pandas-stubs, types-tqdm] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index d50f1b34e..3881ad9f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,8 @@ test = [ "pybtex", "pytest>=6.0.0", "pytest-cov>=4.0.0", - "types-tqdm" + "types-tqdm", + "typing_extensions" ] [project.urls] @@ -110,7 +111,6 @@ check_untyped_defs = true disallow_incomplete_defs = true disallow_untyped_defs = true warn_redundant_casts = true -warn_unused_ignores = true [[tool.mypy.overrides]] module = "scipy.*" diff --git a/src/pymovements/plotting/main_sequence_plot.py b/src/pymovements/plotting/main_sequence_plot.py index f0f05f09e..7edc7238b 100644 --- a/src/pymovements/plotting/main_sequence_plot.py +++ b/src/pymovements/plotting/main_sequence_plot.py @@ -110,7 +110,7 @@ def main_sequence_plot( alpha=alpha, s=marker_size, marker=marker, - # XXX to handle after https://github.com/pydata/xarray/pull/8030 is merged + # to handle after https://github.com/pydata/xarray/pull/8030 is merged **kwargs, # type: ignore ) diff --git a/src/pymovements/plotting/traceplot.py b/src/pymovements/plotting/traceplot.py index 84156bcbb..6839f840c 100644 --- a/src/pymovements/plotting/traceplot.py +++ b/src/pymovements/plotting/traceplot.py @@ -21,15 +21,17 @@ from __future__ import annotations import sys +from typing import Dict from typing import Literal from typing import Sequence -from typing import TypeAlias +from typing import Tuple import matplotlib.colors import matplotlib.pyplot as plt import matplotlib.scale import numpy as np from matplotlib.collections import LineCollection +from typing_extensions import TypeAlias from pymovements.gaze.gaze_dataframe import GazeDataFrame @@ -40,8 +42,8 @@ if 'pytest' in sys.modules: # pragma: no cover matplotlib.use('Agg') -LinearSegmentedColormapType: TypeAlias = dict[ - Literal['red', 'green', 'blue', 'alpha'], Sequence[tuple[float, ...]], +LinearSegmentedColormapType: TypeAlias = Dict[ + Literal['red', 'green', 'blue', 'alpha'], Sequence[Tuple[float, ...]], ] DEFAULT_SEGMENTDATA: LinearSegmentedColormapType = { @@ -201,7 +203,10 @@ def traceplot( elif isinstance(cmap_norm, str): # pylint: disable=protected-access - if (scale_class := matplotlib.scale._scale_mapping.get(cmap_norm, None)) is None: + # to handle after https://github.com/pydata/xarray/pull/8030 is merged + if ( + scale_class := matplotlib.scale._scale_mapping.get(cmap_norm, None) # type: ignore + ) is None: raise ValueError(f'cmap_norm string {cmap_norm} is not supported') norm_class = matplotlib.colors.make_norm_from_scale(scale_class)