Skip to content

Commit

Permalink
docstyle pass
Browse files Browse the repository at this point in the history
  • Loading branch information
bmcfee committed Mar 20, 2024
1 parent d69527c commit 588ee9a
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 57 deletions.
18 changes: 4 additions & 14 deletions mir_eval/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@


def __expand_limits(ax, limits, which="x"):
"""Helper function to expand axis limits"""

"""Expand axis limits"""
if which == "x":
getter, setter = ax.get_xlim, ax.set_xlim
elif which == "y":
Expand Down Expand Up @@ -60,9 +59,7 @@ def __get_axes(ax=None, fig=None):
new_axes : bool
If `True`, the axis object was newly constructed.
If `False`, the axis object already existed.
"""

new_axes = False

if ax is not None:
Expand Down Expand Up @@ -256,7 +253,6 @@ def labeled_intervals(
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
"""

# Get the axes handle
ax, _ = __get_axes(ax=ax)

Expand Down Expand Up @@ -347,6 +343,7 @@ def __init__(self, base, ticks):
self._map = {int(k): v for k, v in zip(base, ticks)}

def __call__(self, x, pos=None):
"""Map the input position to its corresponding interval label"""
return self._map.get(int(x), "")


Expand Down Expand Up @@ -379,7 +376,6 @@ def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs):
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
"""

# This will break if a segment label exists in multiple levels
if levels is None:
levels = list(range(len(intervals_hier)))
Expand Down Expand Up @@ -529,7 +525,6 @@ def pitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs):
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
"""

ax, _ = __get_axes(ax=ax)

times = np.asarray(times)
Expand Down Expand Up @@ -616,7 +611,6 @@ def multipitch(times, frequencies, midi=False, unvoiced=False, ax=None, **kwargs
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
"""

# Get the axes handle
ax, _ = __get_axes(ax=ax)

Expand Down Expand Up @@ -697,7 +691,6 @@ def piano_roll(intervals, pitches=None, midi=None, ax=None, **kwargs):
ax : matplotlib.pyplot.axes._subplots.AxesSubplot
A handle to the (possibly constructed) plot axes
"""

if midi is None:
if pitches is None:
raise ValueError("At least one of `midi` or `pitches` " "must be provided.")
Expand Down Expand Up @@ -745,7 +738,6 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs):
ax
The axis handle for this plot
"""

# Get the axes handle
ax, new_axes = __get_axes(ax=ax)

Expand Down Expand Up @@ -807,12 +799,11 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, **kwargs):


def __ticker_midi_note(x, pos):
"""A ticker function for midi notes.
"""Format midi notes for ticker decoration.
Inputs x are interpreted as midi numbers, and converted
to [NOTE][OCTAVE]+[cents].
"""

NOTES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

cents = float(np.mod(x, 1.0))
Expand All @@ -830,12 +821,11 @@ def __ticker_midi_note(x, pos):


def __ticker_midi_hz(x, pos):
"""A ticker function for midi pitches.
"""Format midi pitches for ticker decoration.
Inputs x are interpreted as midi numbers, and converted
to Hz.
"""

return "{:g}".format(midi_to_hz(x))


Expand Down
24 changes: 5 additions & 19 deletions mir_eval/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@


def validate_boundary(reference_intervals, estimated_intervals, trim):
"""Checks that the input annotations to a segment boundary estimation
"""Check that the input annotations to a segment boundary estimation
metric (i.e. one that only takes in segment intervals) look like valid
segment times, and throws helpful errors if not.
Expand All @@ -101,9 +101,7 @@ def validate_boundary(reference_intervals, estimated_intervals, trim):
:func:`mir_eval.io.load_labeled_intervals`.
trim : bool
will the start and end events be trimmed?
"""

if trim:
# If we're trimming, then we need at least 2 intervals
min_size = 2
Expand All @@ -124,7 +122,7 @@ def validate_boundary(reference_intervals, estimated_intervals, trim):
def validate_structure(
reference_intervals, reference_labels, estimated_intervals, estimated_labels
):
"""Checks that the input annotations to a structure estimation metric (i.e.
"""Check that the input annotations to a structure estimation metric (i.e.
one that takes in both segment boundaries and their labels) look like valid
segment times and labels, and throws helpful errors if not.
Expand Down Expand Up @@ -226,9 +224,7 @@ def detection(
recall of reference reference boundaries
f_measure : float
F-measure (weighted harmonic mean of ``precision`` and ``recall``)
"""

validate_boundary(reference_intervals, estimated_intervals, trim)

# Convert intervals to boundaries
Expand Down Expand Up @@ -288,9 +284,7 @@ def deviation(reference_intervals, estimated_intervals, trim=False):
estimated_to_reference : float
median time from each estimated boundary to the
closest reference boundary
"""

validate_boundary(reference_intervals, estimated_intervals, trim)

# Convert intervals to boundaries
Expand Down Expand Up @@ -471,9 +465,7 @@ def rand_index(
-------
rand_index : float > 0
Rand index
"""

validate_structure(
reference_intervals, reference_labels, estimated_intervals, estimated_labels
)
Expand Down Expand Up @@ -519,7 +511,7 @@ def rand_index(


def _contingency_matrix(reference_indices, estimated_indices):
"""Computes the contingency matrix of a true labeling vs an estimated one.
"""Compute the contingency matrix of a true labeling vs an estimated one.
Parameters
----------
Expand Down Expand Up @@ -602,7 +594,7 @@ def ari(
estimated_labels,
frame_size=0.1,
):
"""Adjusted Rand Index (ARI) for frame clustering segmentation evaluation.
"""Compute the Adjusted Rand Index (ARI) for frame clustering segmentation evaluation.
Examples
--------
Expand Down Expand Up @@ -715,7 +707,7 @@ def _mutual_info_score(reference_indices, estimated_indices, contingency=None):


def _entropy(labels):
"""Calculates the entropy for a labeling.
"""Calculate the entropy for a labeling.
Parameters
----------
Expand Down Expand Up @@ -1044,9 +1036,7 @@ def nce(
If `|y_ref|==1`, then `S_under` will be 0.
S_F
F-measure for (S_over, S_under)
"""

validate_structure(
reference_intervals, reference_labels, estimated_intervals, estimated_labels
)
Expand Down Expand Up @@ -1177,9 +1167,7 @@ def vmeasure(
If `|y_ref|==1`, then `V_recall` will be 0.
V_F
F-measure for (V_precision, V_recall)
"""

return nce(
reference_intervals,
reference_labels,
Expand Down Expand Up @@ -1226,9 +1214,7 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs):
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""

# Adjust timespan of estimations relative to ground truth
ref_intervals, ref_labels = util.adjust_intervals(
ref_intervals, labels=ref_labels, t_min=0.0
Expand Down
14 changes: 3 additions & 11 deletions mir_eval/separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@


def validate(reference_sources, estimated_sources):
"""Checks that the input data to a metric are valid, and throws helpful
"""Check that the input data to a metric are valid, and throws helpful
errors if not.
Parameters
Expand All @@ -71,7 +71,6 @@ def validate(reference_sources, estimated_sources):
matrix containing estimated sources
"""

if reference_sources.shape != estimated_sources.shape:
raise ValueError(
"The shape of estimated sources and the true "
Expand Down Expand Up @@ -136,7 +135,7 @@ def validate(reference_sources, estimated_sources):


def _any_source_silent(sources):
"""Returns true if the parameter sources has any silent first dimensions"""
"""Return true if the parameter sources has any silent first dimensions"""
return np.any(
np.all(np.sum(sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1)
)
Expand Down Expand Up @@ -198,7 +197,6 @@ def bss_eval_sources(reference_sources, estimated_sources, compute_permutation=T
92, pp. 1928-1936, 2012.
"""

# make sure the input is of shape (nsrc, nsampl)
if estimated_sources.ndim == 1:
estimated_sources = estimated_sources[np.newaxis, :]
Expand Down Expand Up @@ -319,9 +317,7 @@ def bss_eval_sources_framewise(
the mean SIR sense (estimated source number ``perm[j]`` corresponds to
true source number ``j``). Note: ``perm`` will be ``range(nsrc)`` for
all windows if ``compute_permutation`` is ``False``
"""

# make sure the input is of shape (nsrc, nsampl)
if estimated_sources.ndim == 1:
estimated_sources = estimated_sources[np.newaxis, :]
Expand Down Expand Up @@ -367,7 +363,7 @@ def bss_eval_sources_framewise(


def bss_eval_images(reference_sources, estimated_sources, compute_permutation=True):
"""Implementation of the bss_eval_images function from the
"""Compute the bss_eval_images function from the
BSS_EVAL Matlab toolbox.
Ordering and measurement of the separation quality for estimated source
Expand Down Expand Up @@ -423,9 +419,7 @@ def bss_eval_images(reference_sources, estimated_sources, compute_permutation=Tr
Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign
(2007-2010): Achievements and remaining challenges", Signal Processing,
92, pp. 1928-1936, 2012.
"""

# make sure the input has 3 dimensions
# assuming input is in shape (nsampl) or (nsrc, nsampl)
estimated_sources = np.atleast_3d(estimated_sources)
Expand Down Expand Up @@ -570,9 +564,7 @@ def bss_eval_images_framewise(
true source number j)
Note: perm will be range(nsrc) for all windows if compute_permutation
is False
"""

# make sure the input has 3 dimensions
# assuming input is in shape (nsampl) or (nsrc, nsampl)
estimated_sources = np.atleast_3d(estimated_sources)
Expand Down
9 changes: 4 additions & 5 deletions mir_eval/sonify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def clicks(times, fs, click=None, length=None):
"""Returns a signal with the signal 'click' placed at each specified time
"""Return a signal with the signal 'click' placed at each specified time
Parameters
----------
Expand Down Expand Up @@ -63,7 +63,7 @@ def clicks(times, fs, click=None, length=None):
def time_frequency(
gram, frequencies, times, fs, function=np.sin, length=None, n_dec=1, threshold=0.01
):
"""Reverse synthesis of a time-frequency representation of a signal
r"""Reverse synthesis of a time-frequency representation of a signal
Parameters
----------
Expand Down Expand Up @@ -126,7 +126,7 @@ def time_frequency(
sample_intervals = np.round(times * fs).astype(int)

def _fast_synthesize(frequency):
"""A faster way to synthesize a signal.
"""Efficiently synthesize a signal.
Generate one cycle, and simulate arbitrary repetitions
using array indexing tricks.
"""
Expand Down Expand Up @@ -217,7 +217,7 @@ def __interpolator(x):
def pitch_contour(
times, frequencies, fs, amplitudes=None, function=np.sin, length=None, kind="linear"
):
"""Sonify a pitch contour.
r"""Sonify a pitch contour.
Parameters
----------
Expand Down Expand Up @@ -245,7 +245,6 @@ def pitch_contour(
output : np.ndarray
synthesized version of the pitch contour
"""

fs = float(fs)

if length is None:
Expand Down
7 changes: 2 additions & 5 deletions mir_eval/tempo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


def validate_tempi(tempi, reference=True):
"""Checks that there are two non-negative tempi.
"""Check that there are two non-negative tempi.
For a reference value, at least one tempo has to be greater than zero.
Parameters
Expand All @@ -36,9 +36,7 @@ def validate_tempi(tempi, reference=True):
length-2 array of tempo, in bpm
reference : bool
indicates a reference value
"""

if tempi.size != 2:
raise ValueError("tempi must have exactly two values")

Expand All @@ -52,7 +50,7 @@ def validate_tempi(tempi, reference=True):


def validate(reference_tempi, reference_weight, estimated_tempi):
"""Checks that the input annotations to a metric look like valid tempo
"""Check that the input annotations to a metric look like valid tempo
annotations.
Parameters
Expand Down Expand Up @@ -109,7 +107,6 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08):
If ``tol < 0`` or ``tol > 1``.
"""

validate(reference_tempi, reference_weight, estimated_tempi)

if tol < 0 or tol > 1:
Expand Down
4 changes: 2 additions & 2 deletions mir_eval/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@


def validate(ref_intervals, ref_pitches, est_intervals, est_pitches):
"""Checks that the input annotations to a metric look like time intervals
"""Check that the input annotations to a metric look like time intervals
and a pitch list, and throws helpful errors if not.
Parameters
Expand Down Expand Up @@ -146,7 +146,7 @@ def validate(ref_intervals, ref_pitches, est_intervals, est_pitches):


def validate_intervals(ref_intervals, est_intervals):
"""Checks that the input annotations to a metric look like time intervals,
"""Check that the input annotations to a metric look like time intervals,
and throws helpful errors if not.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion mir_eval/transcription_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def validate(
est_pitches,
est_velocities,
):
"""Checks that the input annotations have valid time intervals, pitches,
"""Check that the input annotations have valid time intervals, pitches,
and velocities, and throws helpful errors if not.
Parameters
Expand Down

0 comments on commit 588ee9a

Please sign in to comment.