From 2747ef5d70c68e428a6833d4d1b613e44c3a6295 Mon Sep 17 00:00:00 2001 From: Justin Salamon Date: Tue, 9 Feb 2021 14:56:43 -0800 Subject: [PATCH] choose_weighted distribution tuple (#144) * Implement _sample_choose_weighted * Add choose_weighted to SUPPORTED_DIST and move validation logic into _validate_distribution Logic moved over from _sample_choose_weighted in util.py * Add choose_weighted to all functions that handle distribution tuples (see detailed commit message) Added handling of choose_weighted in # SUPPORTED_DIST and the following functions: * _validate_distribution * _validate_time * _validate_duration * _validate_time_stretch * _validate_pitch_shift * _validate_label * _validate_source_file * _validate_snr * _ensure_satisfiable_source_time_tuple * _instantiate_event * add_background (docstring) * add_event (docstring) * Add unit test cases for bad choose_weighted tuple data * Get tests up to 100% * minor docstring cleanup * Bump to v1.6.5 and update changelog * Change version to 1.6.5.rc0 * Typo fix --- docs/changes.rst | 4 ++ scaper/core.py | 98 +++++++++++++++++++++++++++++++++++++--------- scaper/util.py | 27 +++++++++++++ scaper/version.py | 2 +- tests/test_core.py | 23 ++++++++++- tests/test_util.py | 20 +++++++++- 6 files changed, 153 insertions(+), 21 deletions(-) diff --git a/docs/changes.rst b/docs/changes.rst index d63c7df..2c5bc84 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -2,6 +2,10 @@ Changelog --------- +v1.6.5.rc0 +~~~~~~~~~~ +- Added a new distribution tuple: ``("choose_weighted", list_of_options, probabilities)``, which supports weighted sampling: ``list_of_options[i]`` is chosen with probability ``probabilities[i]``. 
+ v1.6.4 ~~~~~~ - Scaper.generate now accepts a new argument for controlling trade-off between speed and quality in pitch shifting and time stretching: diff --git a/scaper/core.py b/scaper/core.py index fe9c3b7..5f6e3ec 100644 --- a/scaper/core.py +++ b/scaper/core.py @@ -24,6 +24,7 @@ from .util import _sample_trunc_norm from .util import _sample_uniform from .util import _sample_choose +from .util import _sample_choose_weighted from .util import _sample_normal from .util import _sample_const from .util import max_polyphony @@ -33,8 +34,25 @@ from .audio import peak_normalize from .version import version as scaper_version + +# HEADS UP! Adding a new distribution tuple? +# Make sure it's properly handled in all of the following: +# SUPPORTED_DIST +# _validate_distribution +# _validate_time +# _validate_duration +# _validate_time_stretch +# _validate_pitch_shift +# _validate_label +# _validate_source_file +# _validate_snr +# _ensure_satisfiable_source_time_tuple +# _instantiate_event +# add_background (docstring) +# add_event (docstring) SUPPORTED_DIST = {"const": _sample_const, "choose": _sample_choose, + "choose_weighted": _sample_choose_weighted, "uniform": _sample_uniform, "normal": _sample_normal, "truncnorm": _sample_trunc_norm} @@ -450,6 +468,27 @@ def _validate_distribution(dist_tuple): raise ScaperError( 'The "choose" distribution tuple must be of length 2 where ' 'the second item is a list.') + # If it's a choose_weighted, tuple must be of length 3, items 2 and 3 must + # be lists of the same length, and the list in item 3 must contain floats + # in the range [0, 1] that sum to 1 (i.e. valid probabilities). 
+ elif dist_tuple[0] == 'choose_weighted': + if len(dist_tuple) != 3: + raise ScaperError('"choose_weighted" distribution tuple must have length 3') + if not isinstance(dist_tuple[1], list) or \ + not isinstance(dist_tuple[2], list) or \ + len(dist_tuple[1]) != len(dist_tuple[2]): + msg = ('The 2nd and 3rd items of the "choose_weighted" distribution tuple ' + 'must be lists of the same length.') + raise ScaperError(msg) + probabilities = np.asarray(dist_tuple[2]) + if probabilities.min() < 0 or probabilities.max() > 1: + msg = ('Values in the probabilities list of the "choose_weighted" ' + 'distribution tuple must be in the range [0, 1].') + raise ScaperError(msg) + if not np.allclose(probabilities.sum(), 1): + msg = ('Values in the probabilities list of the "choose_weighted" ' + 'distribution tuple must sum to 1.') + raise ScaperError(msg) # If it's a uniform distribution, tuple must be of length 3, 2nd item must # be a real number and 3rd item must be real and greater/equal to the 2nd. elif dist_tuple[0] == 'uniform': @@ -527,6 +566,12 @@ def _ensure_satisfiable_source_time_tuple(source_time, source_duration, event_du source_time[1][i] = max(0, source_duration - event_duration) source_time[1] = list(set(source_time[1])) + # For choose_weighted we do the same as choose but without removing duplicates + elif source_time[0] == 'choose_weighted': + for i, t in enumerate(source_time[1]): + if t + event_duration > source_duration: + source_time[1][i] = max(0, source_duration - event_duration) + # If it's a uniform distribution, tuple must be of length 3, We change the 3rd # item to source_duration - event_duration so that we stay in bounds. If the min # out of bounds, we change it to be source_duration - event_duration. 
@@ -596,7 +641,7 @@ def _validate_label(label, allowed_labels): raise ScaperError( 'Label value must match one of the available labels: ' '{:s}'.format(str(allowed_labels))) - elif label[0] == "choose": + elif label[0] == "choose" or label[0] == "choose_weighted": if label[1]: # list is not empty if not set(label[1]).issubset(set(allowed_labels)): raise ScaperError( @@ -640,8 +685,8 @@ def _validate_source_file(source_file_tuple, label_tuple): if label_tuple[0] != "const" or label_tuple[1] != parent_name: raise ScaperError( "Source file's parent folder name does not match label.") - # Otherwise it must be specified using "choose" - elif source_file_tuple[0] == "choose": + # Otherwise it must be specified using one of "choose" or "choose_weighted" + elif source_file_tuple[0] == "choose" or source_file_tuple[0] == "choose_weighted": if source_file_tuple[1]: # list is not empty if not all(os.path.isfile(x) for x in source_file_tuple[1]): raise ScaperError( @@ -678,7 +723,7 @@ def _validate_time(time_tuple): time_tuple[1] < 0): raise ScaperError( 'Time must be a real non-negative number.') - elif time_tuple[0] == "choose": + elif time_tuple[0] == "choose" or time_tuple[0] == "choose_weighted": if (not time_tuple[1] or not is_real_array(time_tuple[1]) or not all(x is not None for x in time_tuple[1]) or @@ -730,7 +775,7 @@ def _validate_duration(duration_tuple): duration_tuple[1] <= 0): raise ScaperError( 'Duration must be a real number greater than zero.') - elif duration_tuple[0] == "choose": + elif duration_tuple[0] == "choose" or duration_tuple[0] == "choose_weighted": if (not duration_tuple[1] or not is_real_array(duration_tuple[1]) or not all(x > 0 for x in duration_tuple[1])): @@ -779,7 +824,7 @@ def _validate_snr(snr_tuple): if not is_real_number(snr_tuple[1]): raise ScaperError( 'SNR must be a real number.') - elif snr_tuple[0] == "choose": + elif snr_tuple[0] == "choose" or snr_tuple[0] == "choose_weighted": if (not snr_tuple[1] or not 
is_real_array(snr_tuple[1])): raise ScaperError( @@ -815,7 +860,7 @@ def _validate_pitch_shift(pitch_shift_tuple): if not is_real_number(pitch_shift_tuple[1]): raise ScaperError( 'Pitch shift must be a real number.') - elif pitch_shift_tuple[0] == "choose": + elif pitch_shift_tuple[0] == "choose" or pitch_shift_tuple[0] == "choose_weighted": if (not pitch_shift_tuple[1] or not is_real_array(pitch_shift_tuple[1])): raise ScaperError( @@ -854,7 +899,7 @@ def _validate_time_stretch(time_stretch_tuple): time_stretch_tuple[1] <= 0): raise ScaperError( 'Time stretch must be a real number greater than zero.') - elif time_stretch_tuple[0] == "choose": + elif time_stretch_tuple[0] == "choose" or time_stretch_tuple[0] == "choose_weighted": if (not time_stretch_tuple[1] or not is_real_array(time_stretch_tuple[1]) or not all(x > 0 for x in time_stretch_tuple[1])): @@ -1134,18 +1179,26 @@ def add_background(self, label, source_file, source_time): value will be chosen at random from all available labels or files as determined automatically by Scaper by examining the file structure of ``bg_path`` provided during initialization. + * ``("choose_weighted", valuelist, probabilities)``: choose a value + from ``valuelist`` via weighted sampling, where the probability of + sampling ``valuelist[i]`` is given by ``probabilities[i]``. The + ``probabilities`` list must contain a valid probability distribution, + i.e., all values must be in the range [0, 1] and sum to one. * ``("uniform", min_value, max_value)`` : sample a random value from a uniform distribution between ``min_value`` and ``max_value``. * ``("normal", mean, stddev)`` : sample a random value from a normal distribution defined by its mean ``mean`` and standard deviation ``stddev``. + * ``("truncnorm", mean, stddev, min, max)``: sample a random value from + a truncated normal distribution defined by its mean ``mean``, standard + deviation ``stddev``, minimum value ``min`` and maximum value ``max``. 
IMPORTANT: not all parameters support all distribution tuples. In - particular, ``label`` and ``source_file`` only support ``"const"`` and - ``"choose"``, whereas ``source_time`` supports all distribution tuples. - As noted above, only ``label`` and ``source_file`` support providing an - empty ``valuelist`` with ``"choose"``. + particular, ``label`` and ``source_file`` only support ``"const"``, + ``"choose"`` and ``"choose_weighted"``, whereas ``source_time`` supports + all distribution tuples. As noted above, only ``label`` and ``source_file`` + support providing an empty ``valuelist`` with ``"choose"``. ''' # These values are fixed for the background sound @@ -1245,18 +1298,26 @@ def add_event(self, label, source_file, source_time, event_time, source files as determined automatically by Scaper by examining the file structure of ``fg_path`` provided during initialization. + * ``("choose_weighted", valuelist, probabilities)``: choose a value + from ``valuelist`` via weighted sampling, where the probability of + sampling ``valuelist[i]`` is given by ``probabilities[i]``. The + ``probabilities`` list must contain a valid probability distribution, + i.e., all values must be in the range [0, 1] and sum to one. * ``("uniform", min_value, max_value)`` : sample a random value from a uniform distribution between ``min_value`` and ``max_value`` (including ``max_value``). * ``("normal", mean, stddev)`` : sample a random value from a normal distribution defined by its mean ``mean`` and standard deviation ``stddev``. + * ``("truncnorm", mean, stddev, min, max)``: sample a random value from + a truncated normal distribution defined by its mean ``mean``, standard + deviation ``stddev``, minimum value ``min`` and maximum value ``max``. IMPORTANT: not all parameters support all distribution tuples. In - particular, ``label`` and ``source_file`` only support ``"const"`` and - ``"choose"``, whereas the remaining parameters support all distribution - tuples. 
As noted above, only ``label`` and ``source_file`` support - providing an empty ``valuelist`` with ``"choose"``. + particular, ``label`` and ``source_file`` only support ``"const"``, + ``"choose"`` and ``"choose_weighted"``, whereas the remaining parameters + support all distribution tuples. As noted above, only ``label`` and + ``source_file`` support providing an empty ``valuelist`` with ``"choose"``. See Also -------- @@ -1355,6 +1416,7 @@ def _instantiate_event(self, event, isbackground=False, allowed_labels = self.fg_labels # determine label + # special case: choose tuple with empty list if event.label[0] == "choose" and not event.label[1]: label_tuple = list(event.label) label_tuple[1] = allowed_labels @@ -1380,9 +1442,9 @@ def _instantiate_event(self, event, isbackground=False, used_labels.append(label) # determine source file + # special case: choose tuple with empty list if event.source_file[0] == "choose" and not event.source_file[1]: - source_files = _get_sorted_files( - os.path.join(file_path, label)) + source_files = _get_sorted_files(os.path.join(file_path, label)) source_file_tuple = list(event.source_file) source_file_tuple[1] = source_files source_file_tuple = tuple(source_file_tuple) diff --git a/scaper/util.py b/scaper/util.py index 33cd352..bf03930 100644 --- a/scaper/util.py +++ b/scaper/util.py @@ -271,6 +271,33 @@ def _sample_choose(list_of_options, random_state): return new_list_of_options[index] +def _sample_choose_weighted(list_of_options, probabilities, random_state): + ''' + Return a random item from ```list_of_options``` using weighted sampling defined + by ```probabilities```, using random_state. The number of items in ```list_of_options``` + and ```probabilities``` must match, and the values in ```probabilities``` must be in the + range [0, 1] and sum to 1. Unlike ```_sample_choose```, duplicates in + ```list_of_options``` are not removed prior to sampling. 
+ + Parameters + ---------- + list_of_options : list + List of items to choose from. + probabilities : list of floats + List of probabilities corresponding to the elements in ```list_of_options```, such + that the item in ```list_of_options[i]``` is chosen with probability ```probabilities[i]```. + random_state : mtrand.RandomState + RandomState object used to sample from this distribution. + + Returns + ------- + value : any + A random item chosen from ```list_of_options```. + + ''' + return random_state.choice(list_of_options, p=probabilities) + + def _sample_trunc_norm(mu, sigma, trunc_min, trunc_max, random_state): ''' Return a random value sampled from a truncated normal distribution with diff --git a/scaper/version.py b/scaper/version.py index 98e03ab..cf25c2c 100644 --- a/scaper/version.py +++ b/scaper/version.py @@ -3,4 +3,4 @@ """Version info""" short_version = '1.6' -version = '1.6.4' +version = '1.6.5.rc0' diff --git a/tests/test_core.py b/tests/test_core.py index 1222deb..9b75fa1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -915,6 +915,14 @@ def test_ensure_satisfiable_source_time_tuple(): assert (warn) assert np.allclose(_adjusted[1], [0, 1, 2, 5]) + _test_dist = ('choose_weighted', [0, 1, 2, 10, 12, 15, 20], + [0.1, 0.2, 0.05, 0.05, 0.3, 0.1, 0.2]) + _adjusted, warn = scaper.core._ensure_satisfiable_source_time_tuple( + _test_dist, source_duration, event_duration) + assert (warn) + assert np.allclose(_adjusted[1], [0, 1, 2, 5, 5, 5, 5]) + assert np.allclose(_adjusted[2], [0.1, 0.2, 0.05, 0.05, 0.3, 0.1, 0.2]) + def test_validate_distribution(): @@ -940,7 +948,20 @@ def __test_bad_tuple_list(tuple_list): # supported dist tuples, but bad arugments badargs = [('const', 1, 2), - ('choose', 1), ('choose', [], 1), + ('choose', 1), + ('choose', [], 1), + ('choose_weighted'), + ('choose_weighted', []), + ('choose_weighted', [1, 2, 3]), + ('choose_weighted', [1, 2, 3], 5), + ('choose_weighted', 5, [0.2, 0.3, 0.5]), + ('choose_weighted', [1, 2, 
3], []), + ('choose_weighted', [1, 2, 3], [0.4, 0.6]), + ('choose_weighted', [1, 2, 3], [0.4, 0.6, 0.1]), + ('choose_weighted', [1, 2, 3], [0.4, 0.3, 0.2]), + ('choose_weighted', [1, 2, 3], [0.7, -0.5, 0.8]), + ('choose_weighted', [1, 2, 3], [1.2, -0.3, 0.1]), + ('choose_weighted', [1, 2, 3], [1.2, 0.3, 0.1]), ('uniform', 1), ('uniform', 1, 2, 3), ('uniform', 2, 1), ('uniform', 'one', 2), ('uniform', 1, 'two'), ('uniform', 0, 1j), ('uniform', 1j, 2), diff --git a/tests/test_util.py b/tests/test_util.py index 443cc6e..215bf84 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -9,7 +9,7 @@ from scaper.util import _validate_folder_path from scaper.util import _get_sorted_files from scaper.util import _populate_label_list -from scaper.util import _sample_trunc_norm, _sample_choose +from scaper.util import _sample_trunc_norm, _sample_choose, _sample_choose_weighted from scaper.util import max_polyphony from scaper.util import polyphony_gini from scaper.util import is_real_number, is_real_array @@ -158,6 +158,24 @@ def test_sample_choose(): pytest.warns(ScaperWarning, _sample_choose, [0, 1, 2, 2, 2], rng) +def test_sample_choose_weighted(): + # make sure probabilities are factored in + rng = _check_random_state(0) + assert _sample_choose_weighted([0, 1, 2], [1, 0, 0], rng) == 0 + assert _sample_choose_weighted([0, 1, 2], [0, 1, 0], rng) == 1 + assert _sample_choose_weighted([0, 1, 2], [0, 0, 1], rng) == 2 + + samples = [] + for _ in range(100000): + samples.append(_sample_choose_weighted([0, 1], [0.3, 0.7], rng)) + + samples = np.asarray(samples) + zero_ratio = (samples == 0).sum() / len(samples) + one_ratio = (samples == 1).sum() / len(samples) + assert np.allclose(zero_ratio, 0.3, atol=1e-2) + assert np.allclose(one_ratio, 0.7, atol=1e-2) + + def test_sample_trunc_norm(): ''' Should return values from a truncated normal distribution.