diff --git a/mir_eval/display.py b/mir_eval/display.py index 685d7ccf..a67ae326 100644 --- a/mir_eval/display.py +++ b/mir_eval/display.py @@ -17,8 +17,7 @@ def __expand_limits(ax, limits, which="x"): - """Helper function to expand axis limits""" - + """Expand axis limits""" if which == "x": getter, setter = ax.get_xlim, ax.set_xlim elif which == "y": @@ -60,9 +59,7 @@ def __get_axes(ax=None, fig=None): new_axes : bool If `True`, the axis object was newly constructed. If `False`, the axis object already existed. - """ - new_axes = False if ax is not None: @@ -256,7 +253,6 @@ def labeled_intervals( ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes """ - # Get the axes handle ax, _ = __get_axes(ax=ax) @@ -347,6 +343,7 @@ def __init__(self, base, ticks): self._map = {int(k): v for k, v in zip(base, ticks)} def __call__(self, x, pos=None): + """Map the input position to its corresponding interval label""" return self._map.get(int(x), "") @@ -379,7 +376,6 @@ def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs): ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes """ - # This will break if a segment label exists in multiple levels if levels is None: levels = list(range(len(intervals_hier))) @@ -529,7 +525,6 @@ def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs): ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes """ - ax, _ = __get_axes(ax=ax) times = np.asarray(times) @@ -616,7 +611,6 @@ def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes """ - # Get the axes handle ax, _ = __get_axes(ax=ax) @@ -697,7 +691,6 @@ def piano_roll(intervals, pitches=None, midi=None, ax=None, **kwargs): ax : matplotlib.pyplot.axes._subplots.AxesSubplot A handle to the (possibly constructed) plot axes """ - if midi is None: if pitches is None: raise ValueError("At least one of `midi` or `pitches` " "must be provided.") @@ -745,7 +738,6 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): ax The axis handle for this plot """ - # Get the axes handle ax, new_axes = __get_axes(ax=ax) @@ -807,12 +799,11 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs): def __ticker_midi_note(x, pos): - """A ticker function for midi notes. + """Format midi notes for ticker decoration. Inputs x are interpreted as midi numbers, and converted to [NOTE][OCTAVE]+[cents]. """ - NOTES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] cents = float(np.mod(x, 1.0)) @@ -830,12 +821,11 @@ def __ticker_midi_note(x, pos): def __ticker_midi_hz(x, pos): - """A ticker function for midi pitches. + """Format midi pitches for ticker decoration. Inputs x are interpreted as midi numbers, and converted to Hz. """ - return "{:g}".format(midi_to_hz(x)) diff --git a/mir_eval/segment.py b/mir_eval/segment.py index 1ccab6cc..7a49d6ff 100644 --- a/mir_eval/segment.py +++ b/mir_eval/segment.py @@ -85,7 +85,7 @@ def validate_boundary(reference_intervals, estimated_intervals, trim): - """Checks that the input annotations to a segment boundary estimation + """Check that the input annotations to a segment boundary estimation metric (i.e. one that only takes in segment intervals) look like valid segment times, and throws helpful errors if not. @@ -101,9 +101,7 @@ def validate_boundary(reference_intervals, estimated_intervals, trim): :func:`mir_eval.io.load_labeled_intervals`. trim : bool will the start and end events be trimmed? - """ - if trim: # If we're trimming, then we need at least 2 intervals min_size = 2 @@ -124,7 +122,7 @@ def validate_boundary(reference_intervals, estimated_intervals, trim): def validate_structure( reference_intervals, reference_labels, estimated_intervals, estimated_labels ): - """Checks that the input annotations to a structure estimation metric (i.e. + """Check that the input annotations to a structure estimation metric (i.e. one that takes in both segment boundaries and their labels) look like valid segment times and labels, and throws helpful errors if not. @@ -226,9 +224,7 @@ def detection( recall of reference reference boundaries f_measure : float F-measure (weighted harmonic mean of ``precision`` and ``recall``) - """ - validate_boundary(reference_intervals, estimated_intervals, trim) # Convert intervals to boundaries @@ -288,9 +284,7 @@ def deviation(reference_intervals, estimated_intervals, trim=False): estimated_to_reference : float median time from each estimated boundary to the closest reference boundary - """ - validate_boundary(reference_intervals, estimated_intervals, trim) # Convert intervals to boundaries @@ -471,9 +465,7 @@ def rand_index( ------- rand_index : float > 0 Rand index - """ - validate_structure( reference_intervals, reference_labels, estimated_intervals, estimated_labels ) @@ -519,7 +511,7 @@ def rand_index( def _contingency_matrix(reference_indices, estimated_indices): - """Computes the contingency matrix of a true labeling vs an estimated one. + """Compute the contingency matrix of a true labeling vs an estimated one. Parameters ---------- @@ -602,7 +594,7 @@ def ari( estimated_labels, frame_size=0.1, ): - """Adjusted Rand Index (ARI) for frame clustering segmentation evaluation. + """Compute the Adjusted Rand Index (ARI) for frame clustering segmentation evaluation. Examples -------- @@ -715,7 +707,7 @@ def _mutual_info_score(reference_indices, estimated_indices, contingency=None): def _entropy(labels): - """Calculates the entropy for a labeling. + """Calculate the entropy for a labeling. Parameters ---------- @@ -1044,9 +1036,7 @@ def nce( If `|y_ref|==1`, then `S_under` will be 0. S_F F-measure for (S_over, S_under) - """ - validate_structure( reference_intervals, reference_labels, estimated_intervals, estimated_labels ) @@ -1177,9 +1167,7 @@ def vmeasure( If `|y_ref|==1`, then `V_recall` will be 0. V_F F-measure for (V_precision, V_recall) - """ - return nce( reference_intervals, reference_labels, @@ -1226,9 +1214,7 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): scores : dict Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. - """ - # Adjust timespan of estimations relative to ground truth ref_intervals, ref_labels = util.adjust_intervals( ref_intervals, labels=ref_labels, t_min=0.0 diff --git a/mir_eval/separation.py b/mir_eval/separation.py index 43a59388..0bb0704e 100644 --- a/mir_eval/separation.py +++ b/mir_eval/separation.py @@ -60,7 +60,7 @@ def validate(reference_sources, estimated_sources): - """Checks that the input data to a metric are valid, and throws helpful + """Check that the input data to a metric are valid, and throws helpful errors if not. Parameters @@ -71,7 +71,6 @@ def validate(reference_sources, estimated_sources): matrix containing estimated sources """ - if reference_sources.shape != estimated_sources.shape: raise ValueError( "The shape of estimated sources and the true " @@ -136,7 +135,7 @@ def validate(reference_sources, estimated_sources): def _any_source_silent(sources): - """Returns true if the parameter sources has any silent first dimensions""" + """Return true if the parameter sources has any silent first dimensions""" return np.any( np.all(np.sum(sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1) ) @@ -198,7 +197,6 @@ def bss_eval_sources(reference_sources, estimated_sources, compute_permutation=T 92, pp. 1928-1936, 2012. """ - # make sure the input is of shape (nsrc, nsampl) if estimated_sources.ndim == 1: estimated_sources = estimated_sources[np.newaxis, :] @@ -319,9 +317,7 @@ def bss_eval_sources_framewise( the mean SIR sense (estimated source number ``perm[j]`` corresponds to true source number ``j``). Note: ``perm`` will be ``range(nsrc)`` for all windows if ``compute_permutation`` is ``False`` - """ - # make sure the input is of shape (nsrc, nsampl) if estimated_sources.ndim == 1: estimated_sources = estimated_sources[np.newaxis, :] @@ -367,7 +363,7 @@ def bss_eval_sources_framewise( def bss_eval_images(reference_sources, estimated_sources, compute_permutation=True): - """Implementation of the bss_eval_images function from the + """Compute the bss_eval_images function from the BSS_EVAL Matlab toolbox. Ordering and measurement of the separation quality for estimated source @@ -423,9 +419,7 @@ def bss_eval_images(reference_sources, estimated_sources, compute_permutation=Tr Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign (2007-2010): Achievements and remaining challenges", Signal Processing, 92, pp. 1928-1936, 2012. - """ - # make sure the input has 3 dimensions # assuming input is in shape (nsampl) or (nsrc, nsampl) estimated_sources = np.atleast_3d(estimated_sources) @@ -570,9 +564,7 @@ def bss_eval_images_framewise( true source number j) Note: perm will be range(nsrc) for all windows if compute_permutation is False - """ - # make sure the input has 3 dimensions # assuming input is in shape (nsampl) or (nsrc, nsampl) estimated_sources = np.atleast_3d(estimated_sources) diff --git a/mir_eval/sonify.py b/mir_eval/sonify.py index 64af0341..cdb56122 100644 --- a/mir_eval/sonify.py +++ b/mir_eval/sonify.py @@ -12,7 +12,7 @@ def clicks(times, fs, click=None, length=None): - """Returns a signal with the signal 'click' placed at each specified time + """Return a signal with the signal 'click' placed at each specified time Parameters ---------- @@ -63,7 +63,7 @@ def clicks(times, fs, click=None, length=None): def time_frequency( gram, frequencies, times, fs, function=np.sin, length=None, n_dec=1, threshold=0.01 ): - """Reverse synthesis of a time-frequency representation of a signal + r"""Reverse synthesis of a time-frequency representation of a signal Parameters ---------- @@ -126,7 +126,7 @@ def time_frequency( sample_intervals = np.round(times * fs).astype(int) def _fast_synthesize(frequency): - """A faster way to synthesize a signal. + """Efficiently synthesize a signal. Generate one cycle, and simulate arbitrary repetitions using array indexing tricks. """ @@ -217,7 +217,7 @@ def __interpolator(x): def pitch_contour( times, frequencies, fs, amplitudes=None, function=np.sin, length=None, kind="linear" ): - """Sonify a pitch contour. + r"""Sonify a pitch contour. Parameters ---------- @@ -245,7 +245,6 @@ def pitch_contour( output : np.ndarray synthesized version of the pitch contour """ - fs = float(fs) if length is None: diff --git a/mir_eval/tempo.py b/mir_eval/tempo.py index 2334f2be..aa090da8 100644 --- a/mir_eval/tempo.py +++ b/mir_eval/tempo.py @@ -27,7 +27,7 @@ def validate_tempi(tempi, reference=True): - """Checks that there are two non-negative tempi. + """Check that there are two non-negative tempi. For a reference value, at least one tempo has to be greater than zero. Parameters @@ -36,9 +36,7 @@ def validate_tempi(tempi, reference=True): length-2 array of tempo, in bpm reference : bool indicates a reference value - """ - if tempi.size != 2: raise ValueError("tempi must have exactly two values") @@ -52,7 +50,7 @@ def validate_tempi(tempi, reference=True): def validate(reference_tempi, reference_weight, estimated_tempi): - """Checks that the input annotations to a metric look like valid tempo + """Check that the input annotations to a metric look like valid tempo annotations. Parameters @@ -109,7 +107,6 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): If ``tol < 0`` or ``tol > 1``. """ - validate(reference_tempi, reference_weight, estimated_tempi) if tol < 0 or tol > 1: diff --git a/mir_eval/transcription.py b/mir_eval/transcription.py index 663268ce..65504279 100644 --- a/mir_eval/transcription.py +++ b/mir_eval/transcription.py @@ -115,7 +115,7 @@ def validate(ref_intervals, ref_pitches, est_intervals, est_pitches): - """Checks that the input annotations to a metric look like time intervals + """Check that the input annotations to a metric look like time intervals and a pitch list, and throws helpful errors if not. Parameters @@ -146,7 +146,7 @@ def validate(ref_intervals, ref_pitches, est_intervals, est_pitches): def validate_intervals(ref_intervals, est_intervals): - """Checks that the input annotations to a metric look like time intervals, + """Check that the input annotations to a metric look like time intervals, and throws helpful errors if not. Parameters diff --git a/mir_eval/transcription_velocity.py b/mir_eval/transcription_velocity.py index 5fead3ae..ee9b3a12 100644 --- a/mir_eval/transcription_velocity.py +++ b/mir_eval/transcription_velocity.py @@ -67,7 +67,7 @@ def validate( est_pitches, est_velocities, ): - """Checks that the input annotations have valid time intervals, pitches, + """Check that the input annotations have valid time intervals, pitches, and velocities, and throws helpful errors if not. Parameters