From 08d58ed9e784bc2c0a6fb26e95d19dfe3b9160dd Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Wed, 19 Sep 2018 12:46:47 -0400 Subject: [PATCH] added a little helper function to vggish to get features directly from waveforms --- openmic/vggish/__init__.py | 38 ++++++++++++++++++++++++++++++ tests/test_openmic_vggish_model.py | 17 +++++++++++++ 2 files changed, 55 insertions(+) diff --git a/openmic/vggish/__init__.py b/openmic/vggish/__init__.py index d77b900..1eda7b6 100644 --- a/openmic/vggish/__init__.py +++ b/openmic/vggish/__init__.py @@ -23,3 +23,41 @@ __pproc__ = Postprocessor(PCA_PARAMS) postprocess = __pproc__.postprocess + + +def waveform_to_features(data, sample_rate, compress=True): '''Converts an audio waveform to VGGish features, with or without PCA compression. Parameters ---------- data : np.array of either one dimension (mono) or two dimensions (stereo) sample_rate: Sample rate of the audio data compress : bool If True, PCA and quantization are applied to the features. If False, the features are taken directly from the model output Returns ------- time_points : np.ndarray, len=n Time points in seconds of the features features : np.ndarray, shape=(n, 128) The output features, with or without PCA compression and quantization.
+ ''' + + import tensorflow as tf + + examples = waveform_to_examples(data, sample_rate) + + with tf.Graph().as_default(), tf.Session() as sess: + time_points, features = transform(examples, sess) + + if compress: + features_z = postprocess(features) + return time_points, features_z + + return time_points, features diff --git a/tests/test_openmic_vggish_model.py b/tests/test_openmic_vggish_model.py index 1830637..8cd0ea3 100644 --- a/tests/test_openmic_vggish_model.py +++ b/tests/test_openmic_vggish_model.py @@ -1,9 +1,12 @@ import pytest +import numpy as np +import soundfile as sf import tensorflow as tf import openmic.vggish.inputs import openmic.vggish.model as model +from openmic.vggish import waveform_to_features def test_model_transform_soundfile(ogg_file): @@ -12,3 +15,17 @@ def test_model_transform_soundfile(ogg_file): time_points, features = model.transform(examples, sess) assert len(time_points) == len(features) > 1 + + +def test_wf_to_features(ogg_file): + data, rate = sf.read(ogg_file) + + time_points_z, features_z = waveform_to_features(data, rate, compress=True) + assert len(time_points_z) == len(features_z) + + time_points, features = waveform_to_features(data, rate, compress=False) + assert len(time_points) == len(features) + + assert np.allclose(time_points, time_points_z) + + assert np.allclose(features_z, openmic.vggish.postprocess(features))