From fd670e62a0438487f1965a69a5ed6576d291b191 Mon Sep 17 00:00:00 2001 From: Nick B Date: Thu, 10 Jan 2019 11:38:54 +0000 Subject: [PATCH] Downbeats: allow weighting of beats per bar * Optional parameter, implicitly defaults to ones for the array * Clean up the handling of lengths into the constructor, it was getting verbose * Check weights don't sum to zero, to avoid divide-by-zero pain. * Weight the HMM results in log space by normalised weight values, as suggested by @Superbock * Add new test to prove that (sufficient, but arbitrary) weighting to 3-time (over 4-time) does indeed return 3-time beats results. This fixes #402. --- madmom/features/downbeats.py | 31 +++++++++++++++++++++++-------- tests/test_features_downbeats.py | 10 ++++++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/madmom/features/downbeats.py b/madmom/features/downbeats.py index c350a4263..129ce0915 100644 --- a/madmom/features/downbeats.py +++ b/madmom/features/downbeats.py @@ -153,6 +153,11 @@ class DBNDownBeatTrackingProcessor(Processor): (down-)beat activation function). fps : float, optional Frames per second. + beats_per_bar_weights : list, optional + Weight the beats_per_bar list when choosing. + Higher numbers favour the beat number at the same index, e.g. + for beats_per_bar of [3, 4], a value here of [1, 2] will bias the + choice towards 4 beats per bar. 
References ---------- @@ -200,11 +205,15 @@ class DBNDownBeatTrackingProcessor(Processor): def __init__(self, beats_per_bar, min_bpm=MIN_BPM, max_bpm=MAX_BPM, num_tempi=NUM_TEMPI, transition_lambda=TRANSITION_LAMBDA, observation_lambda=OBSERVATION_LAMBDA, threshold=THRESHOLD, - correct=CORRECT, fps=None, **kwargs): + correct=CORRECT, fps=None, beats_per_bar_weights=None, + **kwargs): # pylint: disable=unused-argument # pylint: disable=no-name-in-module # expand arguments to arrays beats_per_bar = np.array(beats_per_bar, ndmin=1) + beats_per_bar_weights = (np.array(beats_per_bar_weights, ndmin=1) + if beats_per_bar_weights + else np.ones(beats_per_bar.shape)) min_bpm = np.array(min_bpm, ndmin=1) max_bpm = np.array(max_bpm, ndmin=1) num_tempi = np.array(num_tempi, ndmin=1) @@ -220,11 +229,14 @@ def __init__(self, beats_per_bar, min_bpm=MIN_BPM, max_bpm=MAX_BPM, if len(transition_lambda) != len(beats_per_bar): transition_lambda = np.repeat(transition_lambda, len(beats_per_bar)) - if not (len(min_bpm) == len(max_bpm) == len(num_tempi) == - len(beats_per_bar) == len(transition_lambda)): - raise ValueError('`min_bpm`, `max_bpm`, `num_tempi`, `num_beats` ' - 'and `transition_lambda` must all have the same ' - 'length.') + lengths = [len(a) for a in (min_bpm, max_bpm, num_tempi, beats_per_bar, + transition_lambda, beats_per_bar_weights)] + if not sum(beats_per_bar_weights): + raise ValueError("`beats_per_bar_weights` cannot total zero") + if np.var(lengths): + raise ValueError('`min_bpm`, `max_bpm`, `num_tempi`, `num_beats`, ' + '`beats_per_bar_weights` and `transition_lambda` ' + 'must all have the same length.') # get num_threads from kwargs num_threads = min(len(beats_per_bar), kwargs.get('num_threads', 1)) # init a pool of workers (if needed) @@ -245,6 +257,7 @@ def __init__(self, beats_per_bar, min_bpm=MIN_BPM, max_bpm=MAX_BPM, self.hmms.append(HiddenMarkovModel(tm, om)) # save variables self.beats_per_bar = beats_per_bar + self.beats_per_bar_weights = 
beats_per_bar_weights self.threshold = threshold self.correct = correct self.fps = fps @@ -283,8 +296,10 @@ def process(self, activations, **kwargs): # (parallel) decoding of the activations with HMM results = list(self.map(_process_dbn, zip(self.hmms, it.repeat(activations)))) - # choose the best HMM (highest log probability) - best = np.argmax(np.asarray(results)[:, 1]) + # choose the best HMM (highest log probability) after weighting + weights = self.beats_per_bar_weights + scores = np.asarray(results)[:, 1] + np.log(weights / np.sum(weights)) + best = np.argmax(scores) # the best path through the state space path, _ = results[best] # the state space and observation model of the best HMM diff --git a/tests/test_features_downbeats.py b/tests/test_features_downbeats.py index f1a434afe..390ed084f 100644 --- a/tests/test_features_downbeats.py +++ b/tests/test_features_downbeats.py @@ -101,6 +101,16 @@ def test_process(self): downbeats = self.processor(sample_downbeat_act) self.assertTrue(np.allclose(downbeats, np.empty((0, 2)))) + def test_weighting_measure(self): + self.processor = DBNDownBeatTrackingProcessor( + [3, 4], fps=sample_downbeat_act.fps, + beats_per_bar_weights=[100, 1], correct=False) + downbeats = self.processor(sample_downbeat_act) + correct = np.array([[0.08, 1], [0.43, 2], [0.77, 3], + [1.11, 1], [1.45, 2], [1.79, 3], + [2.13, 1], [2.47, 2]]) + self.assertTrue(np.allclose(downbeats, correct)) + class TestPatternTrackingProcessorClass(unittest.TestCase):