Skip to content

Commit

Permalink
fix mypy to upgrade matplotlib
Browse files Browse the repository at this point in the history
  • Loading branch information
SiQube committed Sep 15, 2023
1 parent 9a10c04 commit d8a884e
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 33 deletions.
6 changes: 4 additions & 2 deletions src/pymovements/plotting/main_sequence_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,12 @@ def main_sequence_plot(
alpha=alpha,
s=marker_size,
marker=marker,
**kwargs,
# XXX to handle after https://github.com/pydata/xarray/pull/8030 is merged
**kwargs, # type: ignore
)

plt.title(title)
if title:
plt.title(title)

Check warning on line 118 in src/pymovements/plotting/main_sequence_plot.py

View check run for this annotation

Codecov / codecov/patch

src/pymovements/plotting/main_sequence_plot.py#L118

Added line #L118 was not covered by tests
plt.xlabel('Amplitude [dva]')
plt.ylabel('Peak Velocity [dva/s]')

Expand Down
81 changes: 50 additions & 31 deletions src/pymovements/plotting/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
from __future__ import annotations

import sys
from typing import Literal
from typing import Sequence
from typing import TypeAlias

import matplotlib
import matplotlib.colors

Check warning on line 28 in src/pymovements/plotting/traceplot.py

View check run for this annotation

Codecov / codecov/patch

src/pymovements/plotting/traceplot.py#L28

Added line #L28 was not covered by tests
import matplotlib.pyplot as plt
import matplotlib.scale

Check warning on line 30 in src/pymovements/plotting/traceplot.py

View check run for this annotation

Codecov / codecov/patch

src/pymovements/plotting/traceplot.py#L30

Added line #L30 was not covered by tests
import numpy as np
from matplotlib import colors
from matplotlib.collections import LineCollection

from pymovements.gaze.gaze_dataframe import GazeDataFrame
Expand All @@ -37,44 +40,59 @@
if 'pytest' in sys.modules: # pragma: no cover
matplotlib.use('Agg')

DEFAULT_SEGMENTDATA = {
LinearSegmentedColormapType: TypeAlias = dict[

Check warning on line 43 in src/pymovements/plotting/traceplot.py

View check run for this annotation

Codecov / codecov/patch

src/pymovements/plotting/traceplot.py#L43

Added line #L43 was not covered by tests
Literal['red', 'green', 'blue', 'alpha'], Sequence[tuple[float, ...]],
]

DEFAULT_SEGMENTDATA: LinearSegmentedColormapType = {

Check warning on line 47 in src/pymovements/plotting/traceplot.py

View check run for this annotation

Codecov / codecov/patch

src/pymovements/plotting/traceplot.py#L47

Added line #L47 was not covered by tests
'red': [
[0.0, 0.0, 0.0],
[0.5, 1.0, 1.0],
[1.0, 1.0, 1.0],
(0.0, 0.0, 0.0),
(0.5, 1.0, 1.0),
(1.0, 1.0, 1.0),
],
'green': [
[0.0, 0.0, 0.0],
[0.5, 1.0, 1.0],
[1.0, 0.0, 0.0],
(0.0, 0.0, 0.0),
(0.5, 1.0, 1.0),
(1.0, 0.0, 0.0),
],
'blue': [
[0.0, 0.0, 0.0],
[0.5, 0.0, 0.0],
[1.0, 0.0, 0.0],
(0.0, 0.0, 0.0),
(0.5, 0.0, 0.0),
(1.0, 0.0, 0.0),
],
'alpha': [
(1.0, 1.0, 1.0),
(1.0, 1.0, 1.0),
(1.0, 1.0, 1.0),
],
}


DEFAULT_SEGMENTDATA_TWOSLOPE = {
DEFAULT_SEGMENTDATA_TWOSLOPE: LinearSegmentedColormapType = {

Check warning on line 71 in src/pymovements/plotting/traceplot.py

View check run for this annotation

Codecov / codecov/patch

src/pymovements/plotting/traceplot.py#L71

Added line #L71 was not covered by tests
'red': [
[0.0, 0.0, 0.0],
[0.5, 0.0, 0.0],
[0.75, 1.0, 1.0],
[1.0, 1.0, 1.0],
(0.0, 0.0, 0.0),
(0.5, 0.0, 0.0),
(0.75, 1.0, 1.0),
(1.0, 1.0, 1.0),
],
'green': [
[0.0, 0.0, 0.0],
[0.25, 1.0, 1.0],
[0.5, 0.0, 0.0],
[0.75, 1.0, 1.0],
[1.0, 0.0, 0.0],
(0.0, 0.0, 0.0),
(0.25, 1.0, 1.0),
(0.5, 0.0, 0.0),
(0.75, 1.0, 1.0),
(1.0, 0.0, 0.0),
],
'blue': [
[0.0, 1.0, 1.0],
[0.25, 1.0, 1.0],
[0.5, 0.0, 0.0],
[1.0, 0.0, 0.0],
(0.0, 1.0, 1.0),
(0.25, 1.0, 1.0),
(0.5, 0.0, 0.0),
(1.0, 0.0, 0.0),
],
'alpha': [
(1.0, 1.0, 1.0),
(1.0, 1.0, 1.0),
(1.0, 1.0, 1.0),
(1.0, 1.0, 1.0),
],
}

Expand All @@ -83,9 +101,9 @@ def traceplot(
gaze: GazeDataFrame,
position_column: str = 'pixel',
cval: np.ndarray | None = None, # pragma: no cover
cmap: colors.Colormap | None = None,
cmap_norm: colors.Normalize | str | None = None,
cmap_segmentdata: dict[str, list[list[float]]] | None = None,
cmap: matplotlib.colors.Colormap | None = None,
cmap_norm: matplotlib.colors.Normalize | str | None = None,
cmap_segmentdata: LinearSegmentedColormapType | None = None,
cbar_label: str | None = None,
show_cbar: bool = False,
padding: float | None = None,
Expand Down Expand Up @@ -148,7 +166,7 @@ def traceplot(
show_cbar = False

cval_max = np.nanmax(np.abs(cval))
cval_min = np.nanmin(cval)
cval_min = np.nanmin(cval).astype(float)

Check warning on line 169 in src/pymovements/plotting/traceplot.py

View check run for this annotation

Codecov / codecov/patch

src/pymovements/plotting/traceplot.py#L169

Added line #L169 was not covered by tests

if cmap_norm is None:
if cval_max and cval_min < 0:
Expand Down Expand Up @@ -218,7 +236,8 @@ def traceplot(
# sm.set_array(cval)
fig.colorbar(line, label=cbar_label, ax=ax)

ax.set_title(title)
if title:
ax.set_title(title)

Check warning on line 240 in src/pymovements/plotting/traceplot.py

View check run for this annotation

Codecov / codecov/patch

src/pymovements/plotting/traceplot.py#L240

Added line #L240 was not covered by tests

if savepath is not None:
fig.savefig(savepath)
Expand Down
25 changes: 25 additions & 0 deletions tests/plotting/main_sequence_plot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,31 @@ def test_main_sequence_plot_not_show(input_df, show, monkeypatch):
mock_function.assert_not_called()


@pytest.mark.parametrize(
('input_df', 'title'),
[
pytest.param(
EventDataFrame(
pl.DataFrame(
{
'amplitude': np.arange(100),
'peak_velocity': np.linspace(10, 50, num=100),
'name': ['saccade' for _ in range(100)],
},
),
),
'foo',
id='do_not_show_plot',
),
],
)
def test_main_sequence_plot_set_title(input_df, title, monkeypatch):
mock_function = Mock()
monkeypatch.setattr(plt, 'title', mock_function)
main_sequence_plot(input_df, title=title)
plt.close()


@pytest.mark.parametrize(
('input_df', 'expected_error', 'error_msg'),
[
Expand Down
4 changes: 4 additions & 0 deletions tests/plotting/traceplot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def gaze_fixture():
{'cval': np.arange(0, 200), 'show_cbar': False},
id='show_cbar_false',
),
pytest.param(
{'cval': np.arange(0, 200), 'title': 'foo'},
id='set_title',
),
],
)
def test_traceplot_show(gaze, kwargs, monkeypatch):
Expand Down

0 comments on commit d8a884e

Please sign in to comment.