diff --git a/mir_eval/display.py b/mir_eval/display.py index a67ae326..f5d5e1ef 100644 --- a/mir_eval/display.py +++ b/mir_eval/display.py @@ -2,40 +2,26 @@ """Display functions""" from collections import defaultdict +from weakref import WeakKeyDictionary import numpy as np from scipy.signal import spectrogram +import matplotlib as mpl from matplotlib.patches import Rectangle from matplotlib.ticker import FuncFormatter, MultipleLocator from matplotlib.ticker import Formatter from matplotlib.colors import LinearSegmentedColormap, LogNorm, ColorConverter from matplotlib.collections import BrokenBarHCollection +from matplotlib.transforms import Bbox, TransformedBbox from .melody import freq_to_voicing from .util import midi_to_hz, hz_to_midi -def __expand_limits(ax, limits, which="x"): - """Expand axis limits""" - if which == "x": - getter, setter = ax.get_xlim, ax.set_xlim - elif which == "y": - getter, setter = ax.get_ylim, ax.set_ylim - else: - raise ValueError("invalid axis: {}".format(which)) - - old_lims = getter() - new_lims = list(limits) - - # infinite limits occur on new axis objects with no data - if np.isfinite(old_lims[0]): - new_lims[0] = min(old_lims[0], limits[0]) - - if np.isfinite(old_lims[1]): - new_lims[1] = max(old_lims[1], limits[1]) - - setter(new_lims) +# This dictionary is used to track mir_eval-specific attributes +# attached to matplotlib axes +__AXMAP = WeakKeyDictionary() def __get_axes(ax=None, fig=None): @@ -62,18 +48,21 @@ def __get_axes(ax=None, fig=None): """ new_axes = False - if ax is not None: - return ax, new_axes + if ax is None: + if fig is None: + import matplotlib.pyplot as plt - if fig is None: - import matplotlib.pyplot as plt + fig = plt.gcf() - fig = plt.gcf() + if not fig.get_axes(): + new_axes = True + ax = fig.gca() - if not fig.get_axes(): - new_axes = True + # Create a storage bucket for this axes in case we need it + if ax not in __AXMAP: + __AXMAP[ax] = dict() - return fig.gca(), new_axes + return ax, new_axes def segments( @@ -84,6 +73,7 @@ def segments( text=False, text_kw=None, ax=None, + prop_cycle=None, **kwargs ): """Plot a segmentation as a set of disjoint rectangles. @@ -103,6 +93,7 @@ def segments( height : number The height of the rectangles. By default, this will be the top of the plot (minus ``base``). + .. note:: If either `base` or `height` are provided, both must be provided. text : bool If true, each segment's label is displayed in its upper-left corner @@ -113,6 +104,9 @@ def segments( ax : matplotlib.pyplot.axes An axis handle on which to draw the segmentation. If none is provided, a new set of axes is created. + prop_cycle : cycle.Cycler + An optional property cycle object to specify style properties. + If not provided, the default property cycler will be retrieved from matplotlib. **kwargs Additional keyword arguments to pass to ``matplotlib.patches.Rectangle``. @@ -135,17 +129,29 @@ def segments( ax, new_axes = __get_axes(ax=ax) - if new_axes: - ax.set_ylim([0, 1]) + if prop_cycle is None: + __AXMAP[ax].setdefault("prop_cycle", mpl.rcParams["axes.prop_cycle"]) + __AXMAP[ax].setdefault("prop_iter", iter(mpl.rcParams["axes.prop_cycle"])) + elif "prop_iter" not in __AXMAP[ax]: + __AXMAP[ax]["prop_cycle"] = prop_cycle + __AXMAP[ax]["prop_iter"] = iter(prop_cycle) - # Infer height - if base is None: - base = ax.get_ylim()[0] + prop_cycle = __AXMAP[ax]["prop_cycle"] + prop_iter = __AXMAP[ax]["prop_iter"] - if height is None: - height = ax.get_ylim()[1] + if new_axes: + ax.set_yticks([]) - cycler = ax._get_patches_for_fill.prop_cycler + if base is None and height is None: + # If neither are provided, we'll use axes coordinates to span the figure + base, height = 0, 1 + transform = ax.get_xaxis_transform() + + elif base is not None and height is not None: + # If both are provided, we'll use data coordinates + transform = None + else: + raise ValueError("When specifying base or height, both must be provided.") seg_map = dict() @@ -153,36 +159,43 @@ def segments( if lab in seg_map: continue - style = next(cycler) + try: + properties = next(prop_iter) + except StopIteration: + prop_iter = iter(prop_cycle) + __AXMAP[ax]["prop_iter"] = prop_iter + properties = next(prop_iter) + + style = { + k: v + for k, v in properties.items() + if k in ["color", "facecolor", "edgecolor", "linewidth"] + } + # Swap color -> facecolor here so we preserve edgecolor on rects + style.setdefault("facecolor", style["color"]) + style.pop("color", None) seg_map[lab] = seg_def_style.copy() seg_map[lab].update(style) - # Swap color -> facecolor here so we preserve edgecolor on rects - seg_map[lab]["facecolor"] = seg_map[lab].pop("color") seg_map[lab].update(kwargs) seg_map[lab]["label"] = lab for ival, lab in zip(intervals, labels): - rect = Rectangle((ival[0], base), ival[1] - ival[0], height, **seg_map[lab]) - ax.add_patch(rect) + rect = ax.axvspan(ival[0], ival[1], ymin=base, ymax=height, **seg_map[lab]) seg_map[lab].pop("label", None) if text: + bbox = Bbox.from_extents(ival[0], base, ival[1], height) + tbbox = TransformedBbox(bbox, transform) ann = ax.annotate( lab, xy=(ival[0], height), - xycoords="data", + xycoords=transform, xytext=(8, -10), textcoords="offset points", + clip_path=rect, + clip_box=tbbox, **text_kw ) - ann.set_clip_path(rect) - - if new_axes: - ax.set_yticks([]) - - # Only expand if we have data - if intervals.size: - __expand_limits(ax, [intervals.min(), intervals.max()], which="x") return ax @@ -196,6 +209,7 @@ def labeled_intervals( extend_labels=True, ax=None, tick=True, + prop_cycle=None, **kwargs ): """Plot labeled intervals with each label on its own row. @@ -244,6 +258,10 @@ def labeled_intervals( tick : bool If ``True``, sets tick positions and labels on the y-axis. + prop_cycle : cycle.Cycler + An optional property cycle object to specify style properties. + If not provided, the default property cycler will be retrieved from matplotlib. + **kwargs Additional keyword arguments to pass to `matplotlib.collection.BrokenBarHCollection`. @@ -254,33 +272,59 @@ def labeled_intervals( A handle to the (possibly constructed) plot axes """ # Get the axes handle - ax, _ = __get_axes(ax=ax) + ax, new_axes = __get_axes(ax=ax) + + if prop_cycle is None: + __AXMAP[ax].setdefault("prop_cycle", mpl.rcParams["axes.prop_cycle"]) + __AXMAP[ax].setdefault("prop_iter", iter(mpl.rcParams["axes.prop_cycle"])) + elif "prop_iter" not in __AXMAP[ax]: + __AXMAP[ax]["prop_cycle"] = prop_cycle + __AXMAP[ax]["prop_iter"] = iter(prop_cycle) + + prop_cycle = __AXMAP[ax]["prop_cycle"] + prop_iter = __AXMAP[ax]["prop_iter"] # Make sure we have a numpy array intervals = np.atleast_2d(intervals) if label_set is None: # If we have non-empty pre-existing tick labels, use them - label_set = [_.get_text() for _ in ax.get_yticklabels()] # If none of the label strings have content, treat it as empty - if not any(label_set): - label_set = [] + label_set = __AXMAP[ax].get("labels", []) else: label_set = list(label_set) # Put additional labels at the end, in order + extended = False if extend_labels: ticks = label_set + sorted(set(labels) - set(label_set)) + if ticks != label_set and len(label_set) > 0: + extended = True elif label_set: ticks = label_set else: ticks = sorted(set(labels)) + # Push the ticks up into the axmap + __AXMAP[ax]["labels"] = ticks + style = dict(linewidth=1) - style.update(next(ax._get_patches_for_fill.prop_cycler)) + try: + properties = next(prop_iter) + except StopIteration: + prop_iter = iter(prop_cycle) + __AXMAP[ax]["prop_iter"] = prop_iter + properties = next(prop_iter) + + style = { + k: v + for k, v in properties.items() + if k in ["color", "facecolor", "edgecolor", "linewidth"] + } # Swap color -> facecolor here so we preserve edgecolor on rects - style["facecolor"] = style.pop("color") + style.setdefault("facecolor", style["color"]) + style.pop("color", None) style.update(kwargs) if base is None: @@ -303,13 +347,13 @@ def labeled_intervals( xvals[lab].append((ival[0], ival[1] - ival[0])) for lab in seg_y: - ax.add_collection(BrokenBarHCollection(xvals[lab], seg_y[lab], **style)) + ax.broken_barh(xvals[lab], seg_y[lab], **style) # Pop the label after the first time we see it, so we only get # one legend entry style.pop("label", None) # Draw a line separating the new labels from pre-existing labels - if label_set != ticks: + if extended: ax.axhline(len(label_set), color="k", alpha=0.5) if tick: @@ -319,11 +363,6 @@ def labeled_intervals( ax.set_yticklabels(ticks, va="bottom") ax.yaxis.set_major_formatter(IntervalFormatter(base, ticks)) - if base.size: - __expand_limits(ax, [base.min(), (base + height).max()], which="y") - if intervals.size: - __expand_limits(ax, [intervals.min(), intervals.max()], which="x") - return ax @@ -389,13 +428,19 @@ def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs): for ints, labs, key in zip(intervals_hier[::-1], labels_hier[::-1], levels[::-1]): labeled_intervals(ints, labs, label=key, ax=ax, **kwargs) - # Reverse the patch ordering for anything we've added. - # This way, intervals are listed in the legend from top to bottom - ax.patches[n_patches:] = ax.patches[n_patches:][::-1] return ax -def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, **kwargs): +def events( + times, + labels=None, + base=None, + height=None, + ax=None, + text_kw=None, + prop_cycle=None, + **kwargs +): """Plot event times as a set of vertical lines Parameters @@ -413,6 +458,7 @@ def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, ** height : number The height of the lines. By default, this will be the top of the plot (minus `base`). + .. note:: If either `base` or `height` are provided, both must be provided. ax : matplotlib.pyplot.axes An axis handle on which to draw the segmentation. If none is provided, a new set of axes is created. @@ -420,6 +466,9 @@ def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, ** If `labels` is provided, the properties of the text objects can be specified here. See `matplotlib.pyplot.Text` for valid parameters + prop_cycle : cycle.Cycler + An optional property cycle object to specify style properties. + If not provided, the default property cycler will be retrieved from matplotlib. **kwargs Additional keyword arguments to pass to `matplotlib.pyplot.vlines`. @@ -441,39 +490,52 @@ def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, ** # Get the axes handle ax, new_axes = __get_axes(ax=ax) - # If we have fresh axes, set the limits - - if new_axes: - # Infer base and height - if base is None: - base = 0 - if height is None: - height = 1 - - ax.set_ylim([base, height]) - else: - if base is None: - base = ax.get_ylim()[0] + if prop_cycle is None: + __AXMAP[ax].setdefault("prop_cycle", mpl.rcParams["axes.prop_cycle"]) + __AXMAP[ax].setdefault("prop_iter", iter(mpl.rcParams["axes.prop_cycle"])) + elif "prop_iter" not in __AXMAP[ax]: + __AXMAP[ax]["prop_cycle"] = prop_cycle + __AXMAP[ax]["prop_iter"] = iter(prop_cycle) - if height is None: - height = ax.get_ylim()[1] + prop_cycle = __AXMAP[ax]["prop_cycle"] + prop_iter = __AXMAP[ax]["prop_iter"] - cycler = ax._get_patches_for_fill.prop_cycler + if base is None and height is None: + # If neither are provided, we'll use axes coordinates to span the figure + base, height = 0, 1 + transform = ax.get_xaxis_transform() - style = next(cycler).copy() + elif base is not None and height is not None: + # If both are provided, we'll use data coordinates + transform = None + else: + raise ValueError("When specifying base or height, both must be provided.") + + # Advance the property iterator if we can, restart it if we must + try: + properties = next(prop_iter) + except StopIteration: + prop_iter = iter(prop_cycle) + __AXMAP[ax]["prop_iter"] = prop_iter + properties = next(prop_iter) + + style = { + k: v for k, v in properties.items() if k in ["color", "linestyle", "linewidth"] + } style.update(kwargs) + # If the user provided 'colors', don't override it with 'color' if "colors" in style: style.pop("color", None) - lines = ax.vlines(times, base, base + height, **style) + lines = ax.vlines(times, base, base + height, transform=transform, **style) if labels: for path, lab in zip(lines.get_paths(), labels): ax.annotate( lab, xy=(path.vertices[0][0], height), - xycoords="data", + xycoords=transform, xytext=(8, -10), textcoords="offset points", **text_kw @@ -482,15 +544,12 @@ def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, ** if new_axes: ax.set_yticks([]) - __expand_limits(ax, [base, base + height], which="y") - - if times.size: - __expand_limits(ax, [times.min(), times.max()], which="x") - return ax -def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): +def pitch( + times, frequencies, midi=False, unvoiced=False, ax=None, prop_cycle=None, **kwargs +): """Visualize pitch contours Parameters @@ -517,6 +576,10 @@ def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): An axis handle on which to draw the pitch contours. If none is provided, a new set of axes is created. + prop_cycle : cycle.Cycler + An optional property cycle object to specify style properties. + If not provided, the default property cycler will be retrieved from matplotlib. + **kwargs Additional keyword arguments to `matplotlib.pyplot.plot`. @@ -527,6 +590,16 @@ def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): """ ax, _ = __get_axes(ax=ax) + if prop_cycle is None: + __AXMAP[ax].setdefault("prop_cycle", mpl.rcParams["axes.prop_cycle"]) + __AXMAP[ax].setdefault("prop_iter", iter(mpl.rcParams["axes.prop_cycle"])) + elif "prop_iter" not in __AXMAP[ax]: + __AXMAP[ax]["prop_cycle"] = prop_cycle + __AXMAP[ax]["prop_iter"] = iter(prop_cycle) + + prop_cycle = __AXMAP[ax]["prop_cycle"] + prop_iter = __AXMAP[ax]["prop_iter"] + times = np.asarray(times) # First, segment into contiguously voiced contours @@ -549,8 +622,12 @@ def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): u_slices.append(idx) # Now we just need to plot the contour - style = dict() - style.update(next(ax._get_lines.prop_cycler)) + try: + style = next(prop_iter) + except StopIteration: + prop_iter = iter(prop_cycle) + __AXMAP[ax]["prop_iter"] = prop_iter + style = next(prop_iter) style.update(kwargs) if midi: @@ -573,7 +650,9 @@ def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): return ax -def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): +def multipitch( + times, frequencies, midi=False, unvoiced=False, ax=None, prop_cycle=None, **kwargs +): """Visualize multiple f0 measurements Parameters @@ -603,6 +682,10 @@ def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs An axis handle on which to draw the pitch contours. If none is provided, a new set of axes is created. + prop_cycle : cycle.Cycler + An optional property cycle object to specify style properties. + If not provided, the default property cycler will be retrieved from matplotlib. + **kwargs Additional keyword arguments to `plt.scatter`. @@ -614,9 +697,24 @@ def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs # Get the axes handle ax, _ = __get_axes(ax=ax) + if prop_cycle is None: + __AXMAP[ax].setdefault("prop_cycle", mpl.rcParams["axes.prop_cycle"]) + __AXMAP[ax].setdefault("prop_iter", iter(mpl.rcParams["axes.prop_cycle"])) + elif "prop_iter" not in __AXMAP[ax]: + __AXMAP[ax]["prop_cycle"] = prop_cycle + __AXMAP[ax]["prop_iter"] = iter(prop_cycle) + + prop_cycle = __AXMAP[ax]["prop_cycle"] + prop_iter = __AXMAP[ax]["prop_iter"] + # Set up a style for the plot - style_voiced = dict() - style_voiced.update(next(ax._get_lines.prop_cycler)) + try: + style_voiced = next(prop_iter) + except StopIteration: + prop_iter = iter(prop_cycle) + __AXMAP[ax]["prop_iter"] = prop_iter + style_voiced = next(prop_iter) + style_voiced.update(kwargs) style_unvoiced = style_voiced.copy() @@ -710,11 +808,21 @@ def piano_roll(intervals, pitches=None, midi=None, ax=None, **kwargs): # Minor tick at each semitone ax.yaxis.set_minor_locator(MultipleLocator(1)) - ax.axis("auto") return ax -def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): +def separation( + sources, + fs=22050, + labels=None, + alpha=0.75, + ax=None, + rasterized=True, + edgecolors="None", + shading="gouraud", + prop_cycle=None, + **kwargs +): """Source-separation visualization Parameters @@ -730,6 +838,17 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): ax : matplotlib.pyplot.axes An axis handle on which to draw the spectrograms. If none is provided, a new set of axes is created. + rasterized : bool + If `True`, the spectrogram is rasterized. + edgecolors : str or None + The color of the edges of the spectrogram patches. + Set to "None" (default) to disable edge coloring. + shading : str + The shading method to use for the spectrogram. + See `matplotlib.pyplot.pcolormesh` for valid options. + prop_cycle : cycle.Cycler + An optional property cycle object to specify colors for each signal. + If not provided, the default property cycler will be retrieved from matplotlib. **kwargs Additional keyword arguments to ``scipy.signal.spectrogram`` @@ -768,12 +887,29 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): color_conv = ColorConverter() + if prop_cycle is None: + __AXMAP[ax].setdefault("prop_cycle", mpl.rcParams["axes.prop_cycle"]) + __AXMAP[ax].setdefault("prop_iter", iter(mpl.rcParams["axes.prop_cycle"])) + elif "prop_iter" not in __AXMAP[ax]: + __AXMAP[ax]["prop_cycle"] = prop_cycle + __AXMAP[ax]["prop_iter"] = iter(prop_cycle) + + prop_cycle = __AXMAP[ax]["prop_cycle"] + prop_iter = __AXMAP[ax]["prop_iter"] + for i, spec in enumerate(specs): # For each source, grab a new color from the cycler # Then construct a colormap that interpolates from # [transparent white -> new color] - color = next(ax._get_lines.prop_cycler)["color"] - color = color_conv.to_rgba(color, alpha=alpha) + # Advance the property iterator if we can, restart it if we must + try: + properties = next(prop_iter) + except StopIteration: + prop_iter = iter(prop_cycle) + __AXMAP[ax]["prop_iter"] = prop_iter + properties = next(prop_iter) + + color = color_conv.to_rgba(properties["color"], alpha=alpha) cmap = LinearSegmentedColormap.from_list( labels[i], [(1.0, 1.0, 1.0, 0.0), color] ) @@ -784,16 +920,16 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): spec, cmap=cmap, norm=LogNorm(vmin=ref_min, vmax=ref_max), - shading="gouraud", - label=labels[i], + rasterized=rasterized, + edgecolors=edgecolors, + shading=shading, ) # Attach a 0x0 rect to the axis with the corresponding label # This way, it will show up in the legend - ax.add_patch(Rectangle((0, 0), 0, 0, color=color, label=labels[i])) - - if new_axes: - ax.axis("tight") + ax.add_patch( + Rectangle((times.min(), freqs.min()), 0, 0, color=color, label=labels[i]) + ) return ax diff --git a/mir_eval/sonify.py b/mir_eval/sonify.py index 45d059f6..e05be5e5 100644 --- a/mir_eval/sonify.py +++ b/mir_eval/sonify.py @@ -224,7 +224,7 @@ def pitch_contour( time indices for each frequency measurement, in seconds frequencies : np.ndarray frequency measurements, in Hz. - Non-positive measurements will be interpreted as un-voiced samples. + Non-positive measurements or NaNs will be interpreted as un-voiced samples. fs : int desired sampling rate of the output signal amplitudes : np.ndarray @@ -252,6 +252,8 @@ def pitch_contour( # Squash the negative frequencies. # wave(0) = 0, so clipping here will un-voice the corresponding instants frequencies = np.maximum(frequencies, 0.0) + # Convert nans to zeros to unvoice + frequencies = np.nan_to_num(frequencies, copy=False) # Build a frequency interpolator f_interp = interp1d( diff --git a/setup.cfg b/setup.cfg index fbf4fd5e..61ac3672 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [tool:pytest] -addopts = --cov-report term-missing --cov mir_eval --cov-report=xml +addopts = --cov-report term-missing --cov mir_eval --cov-report=xml --mpl --mpl-baseline-path=baseline_images/test_display [pydocstyle] # convention = numpy diff --git a/tests/baseline_images/test_display/events.png b/tests/baseline_images/test_display/events.png deleted file mode 100644 index 48673b1f..00000000 Binary files a/tests/baseline_images/test_display/events.png and /dev/null differ diff --git a/tests/baseline_images/test_display/hierarchy_label.png b/tests/baseline_images/test_display/hierarchy_label.png deleted file mode 100644 index c8eb0502..00000000 Binary files a/tests/baseline_images/test_display/hierarchy_label.png and /dev/null differ diff --git a/tests/baseline_images/test_display/hierarchy_nolabel.png b/tests/baseline_images/test_display/hierarchy_nolabel.png deleted file mode 100644 index d80b2241..00000000 Binary files a/tests/baseline_images/test_display/hierarchy_nolabel.png and /dev/null differ diff --git a/tests/baseline_images/test_display/labeled_events.png b/tests/baseline_images/test_display/labeled_events.png deleted file mode 100644 index 1851048c..00000000 Binary files a/tests/baseline_images/test_display/labeled_events.png and /dev/null differ diff --git a/tests/baseline_images/test_display/labeled_intervals.png b/tests/baseline_images/test_display/labeled_intervals.png deleted file mode 100644 index a0418c8b..00000000 Binary files a/tests/baseline_images/test_display/labeled_intervals.png and /dev/null differ diff --git a/tests/baseline_images/test_display/labeled_intervals_compare.png b/tests/baseline_images/test_display/labeled_intervals_compare.png deleted file mode 100644 index dad20920..00000000 Binary files a/tests/baseline_images/test_display/labeled_intervals_compare.png and /dev/null differ diff --git a/tests/baseline_images/test_display/labeled_intervals_compare_common.png b/tests/baseline_images/test_display/labeled_intervals_compare_common.png deleted file mode 100644 index 53afe8c1..00000000 Binary files a/tests/baseline_images/test_display/labeled_intervals_compare_common.png and /dev/null differ diff --git a/tests/baseline_images/test_display/labeled_intervals_compare_noextend.png b/tests/baseline_images/test_display/labeled_intervals_compare_noextend.png deleted file mode 100644 index 3aa3b481..00000000 Binary files a/tests/baseline_images/test_display/labeled_intervals_compare_noextend.png and /dev/null differ diff --git a/tests/baseline_images/test_display/labeled_intervals_noextend.png b/tests/baseline_images/test_display/labeled_intervals_noextend.png deleted file mode 100644 index a0418c8b..00000000 Binary files a/tests/baseline_images/test_display/labeled_intervals_noextend.png and /dev/null differ diff --git a/tests/baseline_images/test_display/multipitch_hz_unvoiced.png b/tests/baseline_images/test_display/multipitch_hz_unvoiced.png deleted file mode 100644 index 65e1e221..00000000 Binary files a/tests/baseline_images/test_display/multipitch_hz_unvoiced.png and /dev/null differ diff --git a/tests/baseline_images/test_display/multipitch_hz_voiced.png b/tests/baseline_images/test_display/multipitch_hz_voiced.png deleted file mode 100644 index 65e1e221..00000000 Binary files a/tests/baseline_images/test_display/multipitch_hz_voiced.png and /dev/null differ diff --git a/tests/baseline_images/test_display/multipitch_midi.png b/tests/baseline_images/test_display/multipitch_midi.png deleted file mode 100644 index 0652575e..00000000 Binary files a/tests/baseline_images/test_display/multipitch_midi.png and /dev/null differ diff --git a/tests/baseline_images/test_display/piano_roll.png b/tests/baseline_images/test_display/piano_roll.png deleted file mode 100644 index 9fa4d165..00000000 Binary files a/tests/baseline_images/test_display/piano_roll.png and /dev/null differ diff --git a/tests/baseline_images/test_display/piano_roll_midi.png b/tests/baseline_images/test_display/piano_roll_midi.png deleted file mode 100644 index 9fa4d165..00000000 Binary files a/tests/baseline_images/test_display/piano_roll_midi.png and /dev/null differ diff --git a/tests/baseline_images/test_display/pitch_hz.png b/tests/baseline_images/test_display/pitch_hz.png deleted file mode 100644 index 7d9a2959..00000000 Binary files a/tests/baseline_images/test_display/pitch_hz.png and /dev/null differ diff --git a/tests/baseline_images/test_display/pitch_midi.png b/tests/baseline_images/test_display/pitch_midi.png deleted file mode 100644 index fe2d53cd..00000000 Binary files a/tests/baseline_images/test_display/pitch_midi.png and /dev/null differ diff --git a/tests/baseline_images/test_display/pitch_midi_hz.png b/tests/baseline_images/test_display/pitch_midi_hz.png deleted file mode 100644 index 1fcadceb..00000000 Binary files a/tests/baseline_images/test_display/pitch_midi_hz.png and /dev/null differ diff --git a/tests/baseline_images/test_display/segment.png b/tests/baseline_images/test_display/segment.png deleted file mode 100644 index a0cb5d35..00000000 Binary files a/tests/baseline_images/test_display/segment.png and /dev/null differ diff --git a/tests/baseline_images/test_display/segment_text.png b/tests/baseline_images/test_display/segment_text.png deleted file mode 100644 index 79efe951..00000000 Binary files a/tests/baseline_images/test_display/segment_text.png and /dev/null differ diff --git a/tests/baseline_images/test_display/separation.png b/tests/baseline_images/test_display/separation.png deleted file mode 100644 index 26e61999..00000000 Binary files a/tests/baseline_images/test_display/separation.png and /dev/null differ diff --git a/tests/baseline_images/test_display/separation_label.png b/tests/baseline_images/test_display/separation_label.png deleted file mode 100644 index 70513020..00000000 Binary files a/tests/baseline_images/test_display/separation_label.png and /dev/null differ diff --git a/tests/baseline_images/test_display/test_display_events.png b/tests/baseline_images/test_display/test_display_events.png new file mode 100644 index 00000000..49cbf472 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_events.png differ diff --git a/tests/baseline_images/test_display/test_display_hierarchy_label.png b/tests/baseline_images/test_display/test_display_hierarchy_label.png new file mode 100644 index 00000000..acb69a81 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_hierarchy_label.png differ diff --git a/tests/baseline_images/test_display/test_display_hierarchy_nolabel.png b/tests/baseline_images/test_display/test_display_hierarchy_nolabel.png new file mode 100644 index 00000000..8dbf65f6 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_hierarchy_nolabel.png differ diff --git a/tests/baseline_images/test_display/test_display_labeled_events.png b/tests/baseline_images/test_display/test_display_labeled_events.png new file mode 100644 index 00000000..d13f83fc Binary files /dev/null and b/tests/baseline_images/test_display/test_display_labeled_events.png differ diff --git a/tests/baseline_images/test_display/test_display_labeled_intervals.png b/tests/baseline_images/test_display/test_display_labeled_intervals.png new file mode 100644 index 00000000..f76608cb Binary files /dev/null and b/tests/baseline_images/test_display/test_display_labeled_intervals.png differ diff --git a/tests/baseline_images/test_display/test_display_labeled_intervals_compare.png b/tests/baseline_images/test_display/test_display_labeled_intervals_compare.png new file mode 100644 index 00000000..c4a5d183 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_labeled_intervals_compare.png differ diff --git a/tests/baseline_images/test_display/test_display_labeled_intervals_compare_common.png b/tests/baseline_images/test_display/test_display_labeled_intervals_compare_common.png new file mode 100644 index 00000000..fdc053b5 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_labeled_intervals_compare_common.png differ diff --git a/tests/baseline_images/test_display/test_display_labeled_intervals_compare_noextend.png b/tests/baseline_images/test_display/test_display_labeled_intervals_compare_noextend.png new file mode 100644 index 00000000..208f189f Binary files /dev/null and b/tests/baseline_images/test_display/test_display_labeled_intervals_compare_noextend.png differ diff --git a/tests/baseline_images/test_display/test_display_labeled_intervals_noextend.png b/tests/baseline_images/test_display/test_display_labeled_intervals_noextend.png new file mode 100644 index 00000000..f76608cb Binary files /dev/null and b/tests/baseline_images/test_display/test_display_labeled_intervals_noextend.png differ diff --git a/tests/baseline_images/test_display/test_display_multipitch_hz_unvoiced.png b/tests/baseline_images/test_display/test_display_multipitch_hz_unvoiced.png new file mode 100644 index 00000000..674d37be Binary files /dev/null and b/tests/baseline_images/test_display/test_display_multipitch_hz_unvoiced.png differ diff --git a/tests/baseline_images/test_display/test_display_multipitch_hz_voiced.png b/tests/baseline_images/test_display/test_display_multipitch_hz_voiced.png new file mode 100644 index 00000000..674d37be Binary files /dev/null and b/tests/baseline_images/test_display/test_display_multipitch_hz_voiced.png differ diff --git a/tests/baseline_images/test_display/test_display_multipitch_midi.png b/tests/baseline_images/test_display/test_display_multipitch_midi.png new file mode 100644 index 00000000..d01bcf6e Binary files /dev/null and b/tests/baseline_images/test_display/test_display_multipitch_midi.png differ diff --git a/tests/baseline_images/test_display/test_display_piano_roll.png b/tests/baseline_images/test_display/test_display_piano_roll.png new file mode 100644 index 00000000..fd38a3ac Binary files /dev/null and b/tests/baseline_images/test_display/test_display_piano_roll.png differ diff --git a/tests/baseline_images/test_display/test_display_piano_roll_midi.png b/tests/baseline_images/test_display/test_display_piano_roll_midi.png new file mode 100644 index 00000000..fd38a3ac Binary files /dev/null and b/tests/baseline_images/test_display/test_display_piano_roll_midi.png differ diff --git a/tests/baseline_images/test_display/test_display_pitch_hz.png b/tests/baseline_images/test_display/test_display_pitch_hz.png new file mode 100644 index 00000000..b987b0bc Binary files /dev/null and b/tests/baseline_images/test_display/test_display_pitch_hz.png differ diff --git a/tests/baseline_images/test_display/test_display_pitch_midi.png b/tests/baseline_images/test_display/test_display_pitch_midi.png new file mode 100644 index 00000000..32adeb80 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_pitch_midi.png differ diff --git a/tests/baseline_images/test_display/test_display_pitch_midi_hz.png b/tests/baseline_images/test_display/test_display_pitch_midi_hz.png new file mode 100644 index 00000000..fcaaced1 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_pitch_midi_hz.png differ diff --git a/tests/baseline_images/test_display/test_display_segment.png b/tests/baseline_images/test_display/test_display_segment.png new file mode 100644 index 00000000..56411f49 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_segment.png differ diff --git a/tests/baseline_images/test_display/test_display_segment_text.png b/tests/baseline_images/test_display/test_display_segment_text.png new file mode 100644 index 00000000..cf2f4a34 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_segment_text.png differ diff --git a/tests/baseline_images/test_display/test_display_separation.png b/tests/baseline_images/test_display/test_display_separation.png new file mode 100644 index 00000000..2a1cade4 Binary files /dev/null and b/tests/baseline_images/test_display/test_display_separation.png differ diff --git a/tests/baseline_images/test_display/test_display_separation_label.png b/tests/baseline_images/test_display/test_display_separation_label.png new file mode 100644 index 00000000..cdb0056f Binary files /dev/null and b/tests/baseline_images/test_display/test_display_separation_label.png differ diff --git a/tests/baseline_images/test_display/test_display_ticker_midi_zoom.png b/tests/baseline_images/test_display/test_display_ticker_midi_zoom.png new file mode 100644 index 00000000..b13974de Binary files /dev/null and b/tests/baseline_images/test_display/test_display_ticker_midi_zoom.png differ diff --git a/tests/baseline_images/test_display/ticker_midi_zoom.png b/tests/baseline_images/test_display/ticker_midi_zoom.png deleted file mode 100644 index a3fc6f20..00000000 Binary files a/tests/baseline_images/test_display/ticker_midi_zoom.png and /dev/null differ diff --git a/tests/test_display.py b/tests/test_display.py index 748bf476..d2c254b0 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -5,16 +5,11 @@ # For testing purposes, clobber the rcfile import matplotlib -matplotlib.use("Agg") # nopep8 - import matplotlib.pyplot as plt import numpy as np import pytest -# We'll make a decorator to handle style contexts -from decorator import decorator - import mir_eval import mir_eval.display from mir_eval.io import load_labeled_intervals @@ -23,20 +18,25 @@ from mir_eval.io import load_ragged_time_series from mir_eval.io import load_wav +from packaging import version -pytestmark = pytest.mark.skip( - reason="disabling display tests until after merger of #370" -) +# Workaround to enable test skipping on older matplotlibs where we know it to be problematic +MPL_VERSION = version.parse(matplotlib.__version__) +OLD_MPL = not (MPL_VERSION >= version.parse("3.8.0")) +# Workaround for old freetype builds with our image fixtures +FT_VERSION = version.parse(matplotlib.ft2font.__freetype_version__) +OLD_FT = not (FT_VERSION >= version.parse("2.10")) -@decorator -def styled(f, *args, **kwargs): - matplotlib.rcdefaults() - return f(*args, **kwargs) +STYLE = "default" -@pytest.mark.mpl_image_compare(baseline_images=["segment"], extensions=["png"]) -@styled +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_segment"], + extensions=["png"], + style=STYLE, + tolerance=6, +) def test_display_segment(): plt.figure() @@ -48,10 +48,16 @@ def test_display_segment(): # Draw a legend plt.legend() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["segment_text"], extensions=["png"]) -@styled +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_segment_text"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +@pytest.mark.xfail(OLD_MPL, reason=f"matplotlib version < {MPL_VERSION}", strict=False) def test_display_segment_text(): plt.figure() @@ -60,12 +66,15 @@ def test_display_segment_text(): # Plot the segments with no labels mir_eval.display.segments(intervals, labels, text=True) + return plt.gcf() @pytest.mark.mpl_image_compare( - baseline_images=["labeled_intervals"], extensions=["png"] + baseline_images=["test_display_labeled_intervals"], + extensions=["png"], + style=STYLE, + tolerance=6, ) -@styled def test_display_labeled_intervals(): plt.figure() @@ -74,12 +83,15 @@ def test_display_labeled_intervals(): # Plot the chords with nothing fancy mir_eval.display.labeled_intervals(intervals, labels) + return plt.gcf() @pytest.mark.mpl_image_compare( - baseline_images=["labeled_intervals_noextend"], extensions=["png"] + baseline_images=["test_display_labeled_intervals_noextend"], + extensions=["png"], + style=STYLE, + tolerance=6, ) -@styled def test_display_labeled_intervals_noextend(): plt.figure() @@ -92,12 +104,15 @@ def test_display_labeled_intervals_noextend(): mir_eval.display.labeled_intervals( intervals, labels, label_set=[], extend_labels=False, ax=ax ) + return plt.gcf() @pytest.mark.mpl_image_compare( - baseline_images=["labeled_intervals_compare"], extensions=["png"] + baseline_images=["test_display_labeled_intervals_compare"], + extensions=["png"], + style=STYLE, + tolerance=6, ) -@styled def test_display_labeled_intervals_compare(): plt.figure() @@ -112,12 +127,15 @@ def test_display_labeled_intervals_compare(): mir_eval.display.labeled_intervals(est_int, est_labels, alpha=0.5, label="Estimate") plt.legend() + return plt.gcf() @pytest.mark.mpl_image_compare( - baseline_images=["labeled_intervals_compare_noextend"], extensions=["png"] + baseline_images=["test_display_labeled_intervals_compare_noextend"], + extensions=["png"], + style=STYLE, + tolerance=6, ) -@styled def test_display_labeled_intervals_compare_noextend(): plt.figure() @@ -134,12 +152,15 @@ def test_display_labeled_intervals_compare_noextend(): ) plt.legend() + return plt.gcf() @pytest.mark.mpl_image_compare( - baseline_images=["labeled_intervals_compare_common"], extensions=["png"] + baseline_images=["test_display_labeled_intervals_compare_common"], + extensions=["png"], + style=STYLE, + tolerance=6, ) -@styled def test_display_labeled_intervals_compare_common(): plt.figure() @@ -158,12 +179,15 @@ def test_display_labeled_intervals_compare_common(): ) plt.legend() + return plt.gcf() @pytest.mark.mpl_image_compare( - baseline_images=["hierarchy_nolabel"], extensions=["png"] + baseline_images=["test_display_hierarchy_nolabel"], + extensions=["png"], + style=STYLE, + tolerance=6, ) -@styled def test_display_hierarchy_nolabel(): plt.figure() @@ -175,10 +199,15 @@ def test_display_hierarchy_nolabel(): mir_eval.display.hierarchy([int0, int1], [lab0, lab1]) plt.legend() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["hierarchy_label"], extensions=["png"]) -@styled +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_hierarchy_label"], + extensions=["png"], + style=STYLE, + tolerance=6, +) def test_display_hierarchy_label(): plt.figure() @@ -190,11 +219,17 @@ def test_display_hierarchy_label(): mir_eval.display.hierarchy([int0, int1], [lab0, lab1], levels=["Large", "Small"]) plt.legend() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["pitch_hz"], extensions=["png"]) -@styled -def test_pitch_hz(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_pitch_hz"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +@pytest.mark.xfail(OLD_FT, reason=f"freetype version < {FT_VERSION}", strict=False) +def test_display_pitch_hz(): plt.figure() ref_times, ref_freqs = load_labeled_events("data/melody/ref00.txt") @@ -204,11 +239,16 @@ def test_pitch_hz(): mir_eval.display.pitch(ref_times, ref_freqs, unvoiced=True, label="Reference") mir_eval.display.pitch(est_times, est_freqs, unvoiced=True, label="Estimate") plt.legend() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["pitch_midi"], extensions=["png"]) -@styled -def test_pitch_midi(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_pitch_midi"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +def test_display_pitch_midi(): plt.figure() times, freqs = load_labeled_events("data/melody/ref00.txt") @@ -216,11 +256,16 @@ def test_pitch_midi(): # Plot pitches on a midi scale with note tickers mir_eval.display.pitch(times, freqs, midi=True) mir_eval.display.ticker_notes() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["pitch_midi_hz"], extensions=["png"]) -@styled -def test_pitch_midi_hz(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_pitch_midi_hz"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +def test_display_pitch_midi_hz(): plt.figure() times, freqs = load_labeled_events("data/melody/ref00.txt") @@ -228,36 +273,47 @@ def test_pitch_midi_hz(): # Plot pitches on a midi scale with note tickers mir_eval.display.pitch(times, freqs, midi=True) mir_eval.display.ticker_pitch() + return plt.gcf() @pytest.mark.mpl_image_compare( - baseline_images=["multipitch_hz_unvoiced"], extensions=["png"] + baseline_images=["test_display_multipitch_hz_unvoiced"], + extensions=["png"], + style=STYLE, + tolerance=6, ) -@styled -def test_multipitch_hz_unvoiced(): +def test_display_multipitch_hz_unvoiced(): plt.figure() times, pitches = load_ragged_time_series("data/multipitch/est01.txt") # Plot pitches on a midi scale with note tickers mir_eval.display.multipitch(times, pitches, midi=False, unvoiced=True) + return plt.gcf() @pytest.mark.mpl_image_compare( - baseline_images=["multipitch_hz_voiced"], extensions=["png"] + baseline_images=["test_display_multipitch_hz_voiced"], + extensions=["png"], + style=STYLE, + tolerance=6, ) -@styled -def test_multipitch_hz_voiced(): +def test_display_multipitch_hz_voiced(): plt.figure() times, pitches = load_ragged_time_series("data/multipitch/est01.txt") mir_eval.display.multipitch(times, pitches, midi=False, unvoiced=False) + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["multipitch_midi"], extensions=["png"]) -@styled -def test_multipitch_midi(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_multipitch_midi"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +def test_display_multipitch_midi(): plt.figure() ref_t, ref_p = load_ragged_time_series("data/multipitch/ref01.txt") @@ -268,11 +324,16 @@ def test_multipitch_midi(): mir_eval.display.multipitch(est_t, est_p, midi=True, alpha=0.5, label="Estimate") plt.legend() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["piano_roll"], extensions=["png"]) -@styled -def test_pianoroll(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_piano_roll"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +def test_display_piano_roll(): plt.figure() ref_t, ref_p = load_valued_intervals("data/transcription/ref04.txt") @@ -284,11 +345,16 @@ def test_pianoroll(): ) plt.legend() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["piano_roll_midi"], extensions=["png"]) -@styled -def test_pianoroll_midi(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_piano_roll_midi"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +def test_display_piano_roll_midi(): plt.figure() ref_t, ref_p = load_valued_intervals("data/transcription/ref04.txt") @@ -302,20 +368,31 @@ def test_pianoroll_midi(): ) plt.legend() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["ticker_midi_zoom"], extensions=["png"]) -@styled -def test_ticker_midi_zoom(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_ticker_midi_zoom"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +def test_display_ticker_midi_zoom(): plt.figure() plt.plot(np.arange(3)) mir_eval.display.ticker_notes() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["separation"], extensions=["png"]) -@styled -def test_separation(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_separation"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +@pytest.mark.xfail(OLD_FT, reason=f"freetype version < {FT_VERSION}", strict=False) +def test_display_separation(): plt.figure() x0, fs = load_wav("data/separation/ref05/0.wav") @@ -323,11 +400,17 @@ def test_separation(): x2, fs = load_wav("data/separation/ref05/2.wav") mir_eval.display.separation([x0, x1, x2], fs=fs) + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["separation_label"], extensions=["png"]) -@styled -def test_separation_label(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_separation_label"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +@pytest.mark.xfail(OLD_FT, reason=f"freetype version < {FT_VERSION}", strict=False) +def test_display_separation_label(): plt.figure() x0, fs = load_wav("data/separation/ref05/0.wav") @@ -337,11 +420,16 @@ def test_separation_label(): mir_eval.display.separation([x0, x1, x2], fs=fs, labels=["Alice", "Bob", "Carol"]) plt.legend() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["events"], extensions=["png"]) -@styled -def test_events(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_events"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +def test_display_events(): plt.figure() # Load some event data @@ -352,11 +440,16 @@ def test_events(): mir_eval.display.events(beats_ref, label="reference") mir_eval.display.events(beats_est, label="estimate") plt.legend() + return plt.gcf() -@pytest.mark.mpl_image_compare(baseline_images=["labeled_events"], extensions=["png"]) -@styled -def test_labeled_events(): +@pytest.mark.mpl_image_compare( + baseline_images=["test_display_labeled_events"], + extensions=["png"], + style=STYLE, + tolerance=6, +) +def test_display_labeled_events(): plt.figure() # Load some event data @@ -365,9 +458,10 @@ def test_labeled_events(): labels = list("abcdefghijklmnop") # Plot both with labels mir_eval.display.events(beats_ref, labels) + return plt.gcf() @pytest.mark.xfail(raises=ValueError) -def test_pianoroll_nopitch_nomidi(): +def test_display_pianoroll_nopitch_nomidi(): # Issue 214 mir_eval.display.piano_roll([[0, 1]])