diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..29d7036f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,110 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.vscode
+.DS_Store
+**/reporting
+
+pretrained_models
+docs/build
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..484aef79
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019-present, Deezer SA.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..31960c79
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,3 @@
+include spleeter/resources/*.json
+include README.md
+include LICENSE
\ No newline at end of file
diff --git a/Makefile b/Makefile
new file mode 100644
index 00000000..989f5ce3
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,30 @@
+# =======================================================
+# Build script for distribution packaging.
+#
+# @author Deezer Research
+# @license MIT License
+# =======================================================
+
+clean:
+	rm -Rf *.egg-info
+	rm -Rf dist
+
+build:
+	@echo "=== Build CPU bdist package"
+	@python3 setup.py sdist
+	@echo "=== CPU version checksum"
+	@openssl sha256 dist/*.tar.gz
+
+build-gpu:
+	@echo "=== Build GPU bdist package"
+	@python3 setup.py sdist --target gpu
+	@echo "=== GPU version checksum"
+	@openssl sha256 dist/*.tar.gz
+
+upload:
+	twine upload dist/*
+
+test-upload:
+	twine upload --repository-url https://test.pypi.org/legacy/ dist/*
+
+all: clean build build-gpu upload
\ No newline at end of file
diff --git a/README.md b/README.md
index a5f6615a..b80f83dd 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,64 @@
-# spleeter
+
-
+[![PyPI version](https://badge.fury.io/py/spleeter.svg)](https://badge.fury.io/py/spleeter) ![Conda](https://img.shields.io/conda/dn/conda-forge/spleeter)
+## About

-spleeter will be made available soon!
+**Spleeter** is the [Deezer](https://www.deezer.com/) source separation library with pretrained models,
+written in [Python](https://www.python.org/) and based on [Tensorflow](https://www.tensorflow.org/). It makes it easy
+to train source separation models (assuming you have a dataset of isolated sources), and provides
+already trained state-of-the-art models for performing various flavours of separation:
+
+* Vocals (singing voice) / accompaniment separation ([2 stems](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-2stems-model))
+* Vocals / drums / bass / other separation ([4 stems](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-4stems-model))
+* Vocals / drums / bass / piano / other separation ([5 stems](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-5stems-model))
+
+The 2 stems and 4 stems models achieve state-of-the-art performance on the
+[musdb](https://sigsep.github.io/datasets/musdb.html) dataset. Separation is also very fast:
+an audio file can be split into 4 stems 100x faster than real-time when run on a *GPU*.
+We designed Spleeter so you can use it straight from the [command line](https://github.com/deezer/spleeter/wiki/2.-Getting-started#usage)
+as well as directly in your own development pipeline as a
+[Python library](https://github.com/deezer/spleeter/wiki/4.-API-Reference#separator).
+
+**Spleeter** can be installed with [Conda](https://github.com/deezer/spleeter/wiki/1.-Installation#using-conda),
+with [pip](https://github.com/deezer/spleeter/wiki/1.-Installation#using-pip) or be used with
+[Docker](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-docker-image).
+
+## Quick start
+
+Want to try it out? Just clone the repository and install a
+[Conda](https://github.com/deezer/spleeter/wiki/1.-Installation#using-conda)
+environment to start separating audio files as follows:
+
+```bash
+$ git clone https://github.com/Deezer/spleeter
+$ conda env create -f spleeter/conda/spleeter-cpu.yaml
+$ conda activate spleeter-cpu
+$ spleeter separate -i spleeter/audio_example.mp3 -p spleeter:2stems -o output
+```
+You should get two separated audio files (`vocals.wav` and `accompaniment.wav`)
+in the `output/audio_example` folder.
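+
+The same separation can also be run from your own Python code. Below is a minimal
+sketch, assuming the `Separator` class described in the
+[API Reference](https://github.com/deezer/spleeter/wiki/4.-API-Reference#separator)
+wiki page (the file paths are illustrative):
+
+```python
+from spleeter.separator import Separator
+
+# Load the configuration of the pretrained 2 stems model.
+separator = Separator('spleeter:2stems')
+
+# Separate an audio file and write one audio file per stem into the output folder.
+separator.separate_to_file('audio_example.mp3', 'output')
+```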
+
+For more detailed documentation, please check the [repository wiki](https://github.com/deezer/spleeter/wiki).
+
+## Reference
+If you use **Spleeter** in your work, please cite:
+
+```
+@misc{spleeter2019,
+  title={Spleeter: A Fast And State-of-the-Art Music Source Separation Tool With Pre-trained Models},
+  author={Romain Hennequin and Anis Khlif and Felix Voituret and Manuel Moussallam},
+  howpublished={Late-Breaking/Demo ISMIR 2019},
+  month={November},
+  year={2019}
+}
+```
+
+## License
+The code of **Spleeter** is MIT-licensed.
+
+## Note
+This repository includes a demo audio file, `audio_example.mp3`, which is an excerpt
+from Slow Motion Dream by Steven M Bryant, (c) copyright 2011, licensed under a Creative
+Commons Attribution (3.0) license: http://dig.ccmixter.org/files/stevieb357/34740
+Ft: CSoul, Alex Beroza & Robert Siekawitch
diff --git a/audio_example.mp3 b/audio_example.mp3
new file mode 100644
index 00000000..cc917888
Binary files /dev/null and b/audio_example.mp3 differ
diff --git a/conda/spleeter-cpu.yaml b/conda/spleeter-cpu.yaml
new file mode 100644
index 00000000..df9ed04b
--- /dev/null
+++ b/conda/spleeter-cpu.yaml
@@ -0,0 +1,18 @@
+name: spleeter-cpu
+
+channels:
+  - conda-forge
+  - anaconda
+
+dependencies:
+  - python=3.7
+  - tensorflow=1.14.0
+  - ffmpeg
+  - pandas==0.25.1
+  - requests
+  - pip
+  - pip:
+    - museval==0.3.0
+    - musdb==0.3.1
+    - norbert==0.2.1
+    - spleeter
diff --git a/conda/spleeter-gpu.yaml b/conda/spleeter-gpu.yaml
new file mode 100644
index 00000000..b269a35f
--- /dev/null
+++ b/conda/spleeter-gpu.yaml
@@ -0,0 +1,19 @@
+name: spleeter-gpu
+
+channels:
+  - conda-forge
+  - anaconda
+
+dependencies:
+  - python=3.7
+  - tensorflow-gpu=1.14.0
+  - ffmpeg
+  - pandas==0.25.1
+  - requests
+  - pip
+  - pip:
+    - museval==0.3.0
+    - musdb==0.3.1
+    - norbert==0.2.1
+    - spleeter
+
diff --git a/configs/2stems/base_config.json b/configs/2stems/base_config.json
new file mode 100644
index 00000000..5f0f5fe2
--- /dev/null
+++ b/configs/2stems/base_config.json
@@ -0,0 +1,28 @@
+{
+    "train_csv": "path/to/train.csv",
+    "validation_csv": "path/to/test.csv",
+    "model_dir": "2stems",
+    "mix_name": "mix",
+    "instrument_list": ["vocals", "accompaniment"],
+    "sample_rate":44100,
+    "frame_length":4096,
+    "frame_step":1024,
+    "T":512,
+    "F":1024,
+    "n_channels":2,
+    "separation_exponent":2,
+    "mask_extension":"zeros",
+    "learning_rate": 1e-4,
+    "batch_size":4,
+    "training_cache":"training_cache",
+    "validation_cache":"validation_cache",
+    "train_max_steps": 1000000,
+    "throttle_secs":300,
+    "random_seed":0,
+    "save_checkpoints_steps":150,
+    "save_summary_steps":5,
+    "model":{
+        "type":"unet.unet",
+        "params":{}
+    }
+}
diff --git a/configs/4stems/base_config.json b/configs/4stems/base_config.json
new file mode 100644
index 00000000..b458c56b
--- /dev/null
+++ b/configs/4stems/base_config.json
@@ -0,0 +1,31 @@
+{
+    "train_csv": "path/to/train.csv",
+    "validation_csv": "path/to/test.csv",
+    "model_dir": "4stems",
+    "mix_name": "mix",
+    "instrument_list": ["vocals", "drums", "bass", "other"],
+    "sample_rate":44100,
+    "frame_length":4096,
+    "frame_step":1024,
+    "T":512,
+    "F":1024,
+    "n_channels":2,
+    "separation_exponent":2,
+    "mask_extension":"zeros",
+    "learning_rate": 1e-4,
+    "batch_size":4,
+    "training_cache":"training_cache",
+    "validation_cache":"validation_cache",
+    "train_max_steps": 1500000,
+    "throttle_secs":600,
+    "random_seed":3,
+    "save_checkpoints_steps":300,
+    "save_summary_steps":5,
+    "model":{
+        "type":"unet.unet",
+        "params":{
+            "conv_activation":"ELU",
+
"deconv_activation":"ELU" + } + } +} diff --git a/configs/5stems/base_config.json b/configs/5stems/base_config.json new file mode 100644 index 00000000..aad63121 --- /dev/null +++ b/configs/5stems/base_config.json @@ -0,0 +1,31 @@ +{ + "train_csv": "path/to/train.csv", + "validation_csv": "path/to/test.csv", + "model_dir": "5stems", + "mix_name": "mix", + "instrument_list": ["vocals", "piano", "drums", "bass", "other"], + "sample_rate":44100, + "frame_length":4096, + "frame_step":1024, + "T":512, + "F":1024, + "n_channels":2, + "separation_exponent":2, + "mask_extension":"zeros", + "learning_rate": 1e-4, + "batch_size":4, + "training_cache":"training_cache", + "validation_cache":"validation_cache", + "train_max_steps": 2500000, + "throttle_secs":600, + "random_seed":8, + "save_checkpoints_steps":300, + "save_summary_steps":5, + "model":{ + "type":"unet.softmax_unet", + "params":{ + "conv_activation":"ELU", + "deconv_activation":"ELU" + } + } +} diff --git a/configs/musdb_config.json b/configs/musdb_config.json new file mode 100644 index 00000000..28b0def7 --- /dev/null +++ b/configs/musdb_config.json @@ -0,0 +1,32 @@ +{ + "train_csv": "configs/musdb_train.csv", + "validation_csv": "configs/musdb_validation.csv", + "model_dir": "musdb_model", + "mix_name": "mix", + "instrument_list": ["vocals", "drums", "bass", "other"], + "sample_rate":44100, + "frame_length":4096, + "frame_step":1024, + "T":512, + "F":1024, + "n_channels":2, + "n_chunks_per_song":1, + "separation_exponent":2, + "mask_extension":"zeros", + "learning_rate": 1e-4, + "batch_size":4, + "training_cache":"cache/training", + "validation_cache":"cache/validation", + "train_max_steps": 100000, + "throttle_secs":600, + "random_seed":3, + "save_checkpoints_steps":300, + "save_summary_steps":5, + "model":{ + "type":"unet.unet", + "params":{ + "conv_activation":"ELU", + "deconv_activation":"ELU" + } + } +} diff --git a/configs/musdb_train.csv b/configs/musdb_train.csv new file mode 100644 index 00000000..b8bab3f6 --- /dev/null +++ b/configs/musdb_train.csv @@ -0,0 +1,87 @@ +mix_path,vocals_path,drums_path,bass_path,other_path,duration +train/A Classic Education - NightOwl/mixture.wav,train/A Classic Education - NightOwl/vocals.wav,train/A Classic Education - NightOwl/drums.wav,train/A Classic Education - NightOwl/bass.wav,train/A Classic Education - NightOwl/other.wav,171.247166 +train/ANiMAL - Clinic A/mixture.wav,train/ANiMAL - Clinic A/vocals.wav,train/ANiMAL - Clinic A/drums.wav,train/ANiMAL - Clinic A/bass.wav,train/ANiMAL - Clinic A/other.wav,237.865215 +train/ANiMAL - Easy Tiger/mixture.wav,train/ANiMAL - Easy Tiger/vocals.wav,train/ANiMAL - Easy Tiger/drums.wav,train/ANiMAL - Easy Tiger/bass.wav,train/ANiMAL - Easy Tiger/other.wav,205.473379 +train/Actions - Devil's Words/mixture.wav,train/Actions - Devil's Words/vocals.wav,train/Actions - Devil's Words/drums.wav,train/Actions - Devil's Words/bass.wav,train/Actions - Devil's Words/other.wav,196.626576 +train/Actions - South Of The Water/mixture.wav,train/Actions - South Of The Water/vocals.wav,train/Actions - South Of The Water/drums.wav,train/Actions - South Of The Water/bass.wav,train/Actions - South Of The Water/other.wav,176.610975 +train/Aimee Norwich - Child/mixture.wav,train/Aimee Norwich - Child/vocals.wav,train/Aimee Norwich - Child/drums.wav,train/Aimee Norwich - Child/bass.wav,train/Aimee Norwich - Child/other.wav,189.080091 +train/Alexander Ross - Velvet Curtain/mixture.wav,train/Alexander Ross - Velvet Curtain/vocals.wav,train/Alexander Ross - Velvet 
Curtain/drums.wav,train/Alexander Ross - Velvet Curtain/bass.wav,train/Alexander Ross - Velvet Curtain/other.wav,514.298776 +train/Angela Thomas Wade - Milk Cow Blues/mixture.wav,train/Angela Thomas Wade - Milk Cow Blues/vocals.wav,train/Angela Thomas Wade - Milk Cow Blues/drums.wav,train/Angela Thomas Wade - Milk Cow Blues/bass.wav,train/Angela Thomas Wade - Milk Cow Blues/other.wav,210.906848 +train/Atlantis Bound - It Was My Fault For Waiting/mixture.wav,train/Atlantis Bound - It Was My Fault For Waiting/vocals.wav,train/Atlantis Bound - It Was My Fault For Waiting/drums.wav,train/Atlantis Bound - It Was My Fault For Waiting/bass.wav,train/Atlantis Bound - It Was My Fault For Waiting/other.wav,268.051156 +train/Auctioneer - Our Future Faces/mixture.wav,train/Auctioneer - Our Future Faces/vocals.wav,train/Auctioneer - Our Future Faces/drums.wav,train/Auctioneer - Our Future Faces/bass.wav,train/Auctioneer - Our Future Faces/other.wav,207.702494 +train/AvaLuna - Waterduct/mixture.wav,train/AvaLuna - Waterduct/vocals.wav,train/AvaLuna - Waterduct/drums.wav,train/AvaLuna - Waterduct/bass.wav,train/AvaLuna - Waterduct/other.wav,259.111474 +train/BigTroubles - Phantom/mixture.wav,train/BigTroubles - Phantom/vocals.wav,train/BigTroubles - Phantom/drums.wav,train/BigTroubles - Phantom/bass.wav,train/BigTroubles - Phantom/other.wav,146.750113 +train/Bill Chudziak - Children Of No-one/mixture.wav,train/Bill Chudziak - Children Of No-one/vocals.wav,train/Bill Chudziak - Children Of No-one/drums.wav,train/Bill Chudziak - Children Of No-one/bass.wav,train/Bill Chudziak - Children Of No-one/other.wav,230.736689 +train/Black Bloc - If You Want Success/mixture.wav,train/Black Bloc - If You Want Success/vocals.wav,train/Black Bloc - If You Want Success/drums.wav,train/Black Bloc - If You Want Success/bass.wav,train/Black Bloc - If You Want Success/other.wav,398.547302 +train/Celestial Shore - Die For Us/mixture.wav,train/Celestial Shore - Die For Us/vocals.wav,train/Celestial Shore - Die For Us/drums.wav,train/Celestial Shore - Die For Us/bass.wav,train/Celestial Shore - Die For Us/other.wav,278.476916 +train/Chris Durban - Celebrate/mixture.wav,train/Chris Durban - Celebrate/vocals.wav,train/Chris Durban - Celebrate/drums.wav,train/Chris Durban - Celebrate/bass.wav,train/Chris Durban - Celebrate/other.wav,301.603991 +train/Clara Berry And Wooldog - Air Traffic/mixture.wav,train/Clara Berry And Wooldog - Air Traffic/vocals.wav,train/Clara Berry And Wooldog - Air Traffic/drums.wav,train/Clara Berry And Wooldog - Air Traffic/bass.wav,train/Clara Berry And Wooldog - Air Traffic/other.wav,173.267302 +train/Clara Berry And Wooldog - Stella/mixture.wav,train/Clara Berry And Wooldog - Stella/vocals.wav,train/Clara Berry And Wooldog - Stella/drums.wav,train/Clara Berry And Wooldog - Stella/bass.wav,train/Clara Berry And Wooldog - Stella/other.wav,195.558458 +train/Cnoc An Tursa - Bannockburn/mixture.wav,train/Cnoc An Tursa - Bannockburn/vocals.wav,train/Cnoc An Tursa - Bannockburn/drums.wav,train/Cnoc An Tursa - Bannockburn/bass.wav,train/Cnoc An Tursa - Bannockburn/other.wav,294.521905 +train/Creepoid - OldTree/mixture.wav,train/Creepoid - OldTree/vocals.wav,train/Creepoid - OldTree/drums.wav,train/Creepoid - OldTree/bass.wav,train/Creepoid - OldTree/other.wav,302.02195 +train/Dark Ride - Burning Bridges/mixture.wav,train/Dark Ride - Burning Bridges/vocals.wav,train/Dark Ride - Burning Bridges/drums.wav,train/Dark Ride - Burning Bridges/bass.wav,train/Dark Ride - Burning Bridges/other.wav,232.663946 
+train/Dreamers Of The Ghetto - Heavy Love/mixture.wav,train/Dreamers Of The Ghetto - Heavy Love/vocals.wav,train/Dreamers Of The Ghetto - Heavy Love/drums.wav,train/Dreamers Of The Ghetto - Heavy Love/bass.wav,train/Dreamers Of The Ghetto - Heavy Love/other.wav,294.800544 +train/Drumtracks - Ghost Bitch/mixture.wav,train/Drumtracks - Ghost Bitch/vocals.wav,train/Drumtracks - Ghost Bitch/drums.wav,train/Drumtracks - Ghost Bitch/bass.wav,train/Drumtracks - Ghost Bitch/other.wav,356.913923 +train/Faces On Film - Waiting For Ga/mixture.wav,train/Faces On Film - Waiting For Ga/vocals.wav,train/Faces On Film - Waiting For Ga/drums.wav,train/Faces On Film - Waiting For Ga/bass.wav,train/Faces On Film - Waiting For Ga/other.wav,257.439637 +train/Fergessen - Back From The Start/mixture.wav,train/Fergessen - Back From The Start/vocals.wav,train/Fergessen - Back From The Start/drums.wav,train/Fergessen - Back From The Start/bass.wav,train/Fergessen - Back From The Start/other.wav,168.553651 +train/Fergessen - The Wind/mixture.wav,train/Fergessen - The Wind/vocals.wav,train/Fergessen - The Wind/drums.wav,train/Fergessen - The Wind/bass.wav,train/Fergessen - The Wind/other.wav,191.820045 +train/Flags - 54/mixture.wav,train/Flags - 54/vocals.wav,train/Flags - 54/drums.wav,train/Flags - 54/bass.wav,train/Flags - 54/other.wav,315.164444 +train/Giselle - Moss/mixture.wav,train/Giselle - Moss/vocals.wav,train/Giselle - Moss/drums.wav,train/Giselle - Moss/bass.wav,train/Giselle - Moss/other.wav,201.711746 +train/Grants - PunchDrunk/mixture.wav,train/Grants - PunchDrunk/vocals.wav,train/Grants - PunchDrunk/drums.wav,train/Grants - PunchDrunk/bass.wav,train/Grants - PunchDrunk/other.wav,204.405261 +train/Helado Negro - Mitad Del Mundo/mixture.wav,train/Helado Negro - Mitad Del Mundo/vocals.wav,train/Helado Negro - Mitad Del Mundo/drums.wav,train/Helado Negro - Mitad Del Mundo/bass.wav,train/Helado Negro - Mitad Del Mundo/other.wav,181.672925 +train/Hezekiah Jones - Borrowed Heart/mixture.wav,train/Hezekiah Jones - Borrowed Heart/vocals.wav,train/Hezekiah Jones - Borrowed Heart/drums.wav,train/Hezekiah Jones - Borrowed Heart/bass.wav,train/Hezekiah Jones - Borrowed Heart/other.wav,241.394649 +train/Hollow Ground - Left Blind/mixture.wav,train/Hollow Ground - Left Blind/vocals.wav,train/Hollow Ground - Left Blind/drums.wav,train/Hollow Ground - Left Blind/bass.wav,train/Hollow Ground - Left Blind/other.wav,159.103129 +train/Hop Along - Sister Cities/mixture.wav,train/Hop Along - Sister Cities/vocals.wav,train/Hop Along - Sister Cities/drums.wav,train/Hop Along - Sister Cities/bass.wav,train/Hop Along - Sister Cities/other.wav,283.237007 +train/Invisible Familiars - Disturbing Wildlife/mixture.wav,train/Invisible Familiars - Disturbing Wildlife/vocals.wav,train/Invisible Familiars - Disturbing Wildlife/drums.wav,train/Invisible Familiars - Disturbing Wildlife/bass.wav,train/Invisible Familiars - Disturbing Wildlife/other.wav,218.499773 +train/James May - All Souls Moon/mixture.wav,train/James May - All Souls Moon/vocals.wav,train/James May - All Souls Moon/drums.wav,train/James May - All Souls Moon/bass.wav,train/James May - All Souls Moon/other.wav,220.844989 +train/James May - Dont Let Go/mixture.wav,train/James May - Dont Let Go/vocals.wav,train/James May - Dont Let Go/drums.wav,train/James May - Dont Let Go/bass.wav,train/James May - Dont Let Go/other.wav,241.951927 +train/James May - If You Say/mixture.wav,train/James May - If You Say/vocals.wav,train/James May - If You Say/drums.wav,train/James May - If 
You Say/bass.wav,train/James May - If You Say/other.wav,258.321995 +train/Jay Menon - Through My Eyes/mixture.wav,train/Jay Menon - Through My Eyes/vocals.wav,train/Jay Menon - Through My Eyes/drums.wav,train/Jay Menon - Through My Eyes/bass.wav,train/Jay Menon - Through My Eyes/other.wav,253.167166 +train/Johnny Lokke - Whisper To A Scream/mixture.wav,train/Johnny Lokke - Whisper To A Scream/vocals.wav,train/Johnny Lokke - Whisper To A Scream/drums.wav,train/Johnny Lokke - Whisper To A Scream/bass.wav,train/Johnny Lokke - Whisper To A Scream/other.wav,255.326621 +"train/Jokers, Jacks & Kings - Sea Of Leaves/mixture.wav","train/Jokers, Jacks & Kings - Sea Of Leaves/vocals.wav","train/Jokers, Jacks & Kings - Sea Of Leaves/drums.wav","train/Jokers, Jacks & Kings - Sea Of Leaves/bass.wav","train/Jokers, Jacks & Kings - Sea Of Leaves/other.wav",191.471746 +train/Leaf - Come Around/mixture.wav,train/Leaf - Come Around/vocals.wav,train/Leaf - Come Around/drums.wav,train/Leaf - Come Around/bass.wav,train/Leaf - Come Around/other.wav,264.382404 +train/Leaf - Wicked/mixture.wav,train/Leaf - Wicked/vocals.wav,train/Leaf - Wicked/drums.wav,train/Leaf - Wicked/bass.wav,train/Leaf - Wicked/other.wav,190.635828 +train/Lushlife - Toynbee Suite/mixture.wav,train/Lushlife - Toynbee Suite/vocals.wav,train/Lushlife - Toynbee Suite/drums.wav,train/Lushlife - Toynbee Suite/bass.wav,train/Lushlife - Toynbee Suite/other.wav,628.378413 +train/Matthew Entwistle - Dont You Ever/mixture.wav,train/Matthew Entwistle - Dont You Ever/vocals.wav,train/Matthew Entwistle - Dont You Ever/drums.wav,train/Matthew Entwistle - Dont You Ever/bass.wav,train/Matthew Entwistle - Dont You Ever/other.wav,113.824218 +train/Meaxic - You Listen/mixture.wav,train/Meaxic - You Listen/vocals.wav,train/Meaxic - You Listen/drums.wav,train/Meaxic - You Listen/bass.wav,train/Meaxic - You Listen/other.wav,412.525714 +train/Music Delta - 80s Rock/mixture.wav,train/Music Delta - 80s Rock/vocals.wav,train/Music Delta - 80s Rock/drums.wav,train/Music Delta - 80s Rock/bass.wav,train/Music Delta - 80s Rock/other.wav,36.733968 +train/Music Delta - Beatles/mixture.wav,train/Music Delta - Beatles/vocals.wav,train/Music Delta - Beatles/drums.wav,train/Music Delta - Beatles/bass.wav,train/Music Delta - Beatles/other.wav,36.176689 +train/Music Delta - Britpop/mixture.wav,train/Music Delta - Britpop/vocals.wav,train/Music Delta - Britpop/drums.wav,train/Music Delta - Britpop/bass.wav,train/Music Delta - Britpop/other.wav,36.594649 +train/Music Delta - Country1/mixture.wav,train/Music Delta - Country1/vocals.wav,train/Music Delta - Country1/drums.wav,train/Music Delta - Country1/bass.wav,train/Music Delta - Country1/other.wav,34.551293 +train/Music Delta - Country2/mixture.wav,train/Music Delta - Country2/vocals.wav,train/Music Delta - Country2/drums.wav,train/Music Delta - Country2/bass.wav,train/Music Delta - Country2/other.wav,17.275646 +train/Music Delta - Disco/mixture.wav,train/Music Delta - Disco/vocals.wav,train/Music Delta - Disco/drums.wav,train/Music Delta - Disco/bass.wav,train/Music Delta - Disco/other.wav,124.598277 +train/Music Delta - Gospel/mixture.wav,train/Music Delta - Gospel/vocals.wav,train/Music Delta - Gospel/drums.wav,train/Music Delta - Gospel/bass.wav,train/Music Delta - Gospel/other.wav,75.557732 +train/Music Delta - Grunge/mixture.wav,train/Music Delta - Grunge/vocals.wav,train/Music Delta - Grunge/drums.wav,train/Music Delta - Grunge/bass.wav,train/Music Delta - Grunge/other.wav,41.656599 +train/Music Delta - 
Hendrix/mixture.wav,train/Music Delta - Hendrix/vocals.wav,train/Music Delta - Hendrix/drums.wav,train/Music Delta - Hendrix/bass.wav,train/Music Delta - Hendrix/other.wav,19.644082 +train/Music Delta - Punk/mixture.wav,train/Music Delta - Punk/vocals.wav,train/Music Delta - Punk/drums.wav,train/Music Delta - Punk/bass.wav,train/Music Delta - Punk/other.wav,28.583764 +train/Music Delta - Reggae/mixture.wav,train/Music Delta - Reggae/vocals.wav,train/Music Delta - Reggae/drums.wav,train/Music Delta - Reggae/bass.wav,train/Music Delta - Reggae/other.wav,17.275646 +train/Music Delta - Rock/mixture.wav,train/Music Delta - Rock/vocals.wav,train/Music Delta - Rock/drums.wav,train/Music Delta - Rock/bass.wav,train/Music Delta - Rock/other.wav,12.910295 +train/Music Delta - Rockabilly/mixture.wav,train/Music Delta - Rockabilly/vocals.wav,train/Music Delta - Rockabilly/drums.wav,train/Music Delta - Rockabilly/bass.wav,train/Music Delta - Rockabilly/other.wav,25.75093 +train/Night Panther - Fire/mixture.wav,train/Night Panther - Fire/vocals.wav,train/Night Panther - Fire/drums.wav,train/Night Panther - Fire/bass.wav,train/Night Panther - Fire/other.wav,212.810884 +train/North To Alaska - All The Same/mixture.wav,train/North To Alaska - All The Same/vocals.wav,train/North To Alaska - All The Same/drums.wav,train/North To Alaska - All The Same/bass.wav,train/North To Alaska - All The Same/other.wav,247.965896 +train/Patrick Talbot - Set Me Free/mixture.wav,train/Patrick Talbot - Set Me Free/vocals.wav,train/Patrick Talbot - Set Me Free/drums.wav,train/Patrick Talbot - Set Me Free/bass.wav,train/Patrick Talbot - Set Me Free/other.wav,289.785034 +train/Phre The Eon - Everybody's Falling Apart/mixture.wav,train/Phre The Eon - Everybody's Falling Apart/vocals.wav,train/Phre The Eon - Everybody's Falling Apart/drums.wav,train/Phre The Eon - Everybody's Falling Apart/bass.wav,train/Phre The Eon - Everybody's Falling Apart/other.wav,224.235102 +train/Port St Willow - Stay Even/mixture.wav,train/Port St Willow - Stay Even/vocals.wav,train/Port St Willow - Stay Even/drums.wav,train/Port St Willow - Stay Even/bass.wav,train/Port St Willow - Stay Even/other.wav,316.836281 +train/Remember December - C U Next Time/mixture.wav,train/Remember December - C U Next Time/vocals.wav,train/Remember December - C U Next Time/drums.wav,train/Remember December - C U Next Time/bass.wav,train/Remember December - C U Next Time/other.wav,242.532426 +train/Secret Mountains - High Horse/mixture.wav,train/Secret Mountains - High Horse/vocals.wav,train/Secret Mountains - High Horse/drums.wav,train/Secret Mountains - High Horse/bass.wav,train/Secret Mountains - High Horse/other.wav,355.311746 +train/Skelpolu - Together Alone/mixture.wav,train/Skelpolu - Together Alone/vocals.wav,train/Skelpolu - Together Alone/drums.wav,train/Skelpolu - Together Alone/bass.wav,train/Skelpolu - Together Alone/other.wav,325.822404 +train/Snowmine - Curfews/mixture.wav,train/Snowmine - Curfews/vocals.wav,train/Snowmine - Curfews/drums.wav,train/Snowmine - Curfews/bass.wav,train/Snowmine - Curfews/other.wav,275.017143 +train/Spike Mullings - Mike's Sulking/mixture.wav,train/Spike Mullings - Mike's Sulking/vocals.wav,train/Spike Mullings - Mike's Sulking/drums.wav,train/Spike Mullings - Mike's Sulking/bass.wav,train/Spike Mullings - Mike's Sulking/other.wav,256.696599 +train/St Vitus - Word Gets Around/mixture.wav,train/St Vitus - Word Gets Around/vocals.wav,train/St Vitus - Word Gets Around/drums.wav,train/St Vitus - Word Gets Around/bass.wav,train/St 
Vitus - Word Gets Around/other.wav,247.013878 +train/Steven Clark - Bounty/mixture.wav,train/Steven Clark - Bounty/vocals.wav,train/Steven Clark - Bounty/drums.wav,train/Steven Clark - Bounty/bass.wav,train/Steven Clark - Bounty/other.wav,289.274195 +train/Strand Of Oaks - Spacestation/mixture.wav,train/Strand Of Oaks - Spacestation/vocals.wav,train/Strand Of Oaks - Spacestation/drums.wav,train/Strand Of Oaks - Spacestation/bass.wav,train/Strand Of Oaks - Spacestation/other.wav,243.670204 +train/Sweet Lights - You Let Me Down/mixture.wav,train/Sweet Lights - You Let Me Down/vocals.wav,train/Sweet Lights - You Let Me Down/drums.wav,train/Sweet Lights - You Let Me Down/bass.wav,train/Sweet Lights - You Let Me Down/other.wav,391.790295 +train/Swinging Steaks - Lost My Way/mixture.wav,train/Swinging Steaks - Lost My Way/vocals.wav,train/Swinging Steaks - Lost My Way/drums.wav,train/Swinging Steaks - Lost My Way/bass.wav,train/Swinging Steaks - Lost My Way/other.wav,309.963175 +train/The Districts - Vermont/mixture.wav,train/The Districts - Vermont/vocals.wav,train/The Districts - Vermont/drums.wav,train/The Districts - Vermont/bass.wav,train/The Districts - Vermont/other.wav,227.973515 +train/The Long Wait - Back Home To Blue/mixture.wav,train/The Long Wait - Back Home To Blue/vocals.wav,train/The Long Wait - Back Home To Blue/drums.wav,train/The Long Wait - Back Home To Blue/bass.wav,train/The Long Wait - Back Home To Blue/other.wav,260.458231 +train/The Scarlet Brand - Les Fleurs Du Mal/mixture.wav,train/The Scarlet Brand - Les Fleurs Du Mal/vocals.wav,train/The Scarlet Brand - Les Fleurs Du Mal/drums.wav,train/The Scarlet Brand - Les Fleurs Du Mal/bass.wav,train/The Scarlet Brand - Les Fleurs Du Mal/other.wav,303.438367 +train/The So So Glos - Emergency/mixture.wav,train/The So So Glos - Emergency/vocals.wav,train/The So So Glos - Emergency/drums.wav,train/The So So Glos - Emergency/bass.wav,train/The So So Glos - Emergency/other.wav,166.812154 +train/The Wrong'Uns - Rothko/mixture.wav,train/The Wrong'Uns - Rothko/vocals.wav,train/The Wrong'Uns - Rothko/drums.wav,train/The Wrong'Uns - Rothko/bass.wav,train/The Wrong'Uns - Rothko/other.wav,202.152925 +train/Tim Taler - Stalker/mixture.wav,train/Tim Taler - Stalker/vocals.wav,train/Tim Taler - Stalker/drums.wav,train/Tim Taler - Stalker/bass.wav,train/Tim Taler - Stalker/other.wav,237.633016 +train/Titanium - Haunted Age/mixture.wav,train/Titanium - Haunted Age/vocals.wav,train/Titanium - Haunted Age/drums.wav,train/Titanium - Haunted Age/bass.wav,train/Titanium - Haunted Age/other.wav,248.105215 +train/Traffic Experiment - Once More (With Feeling)/mixture.wav,train/Traffic Experiment - Once More (With Feeling)/vocals.wav,train/Traffic Experiment - Once More (With Feeling)/drums.wav,train/Traffic Experiment - Once More (With Feeling)/bass.wav,train/Traffic Experiment - Once More (With Feeling)/other.wav,435.07229 +train/Triviul - Dorothy/mixture.wav,train/Triviul - Dorothy/vocals.wav,train/Triviul - Dorothy/drums.wav,train/Triviul - Dorothy/bass.wav,train/Triviul - Dorothy/other.wav,187.361814 +train/Voelund - Comfort Lives In Belief/mixture.wav,train/Voelund - Comfort Lives In Belief/vocals.wav,train/Voelund - Comfort Lives In Belief/drums.wav,train/Voelund - Comfort Lives In Belief/bass.wav,train/Voelund - Comfort Lives In Belief/other.wav,209.90839 +train/Wall Of Death - Femme/mixture.wav,train/Wall Of Death - Femme/vocals.wav,train/Wall Of Death - Femme/drums.wav,train/Wall Of Death - Femme/bass.wav,train/Wall Of Death - 
Femme/other.wav,238.933333 +train/Young Griffo - Blood To Bone/mixture.wav,train/Young Griffo - Blood To Bone/vocals.wav,train/Young Griffo - Blood To Bone/drums.wav,train/Young Griffo - Blood To Bone/bass.wav,train/Young Griffo - Blood To Bone/other.wav,254.397823 +train/Young Griffo - Facade/mixture.wav,train/Young Griffo - Facade/vocals.wav,train/Young Griffo - Facade/drums.wav,train/Young Griffo - Facade/bass.wav,train/Young Griffo - Facade/other.wav,167.857052 diff --git a/configs/musdb_validation.csv b/configs/musdb_validation.csv new file mode 100644 index 00000000..8f1206dd --- /dev/null +++ b/configs/musdb_validation.csv @@ -0,0 +1,15 @@ +mix_path,vocals_path,drums_path,bass_path,other_path,duration +train/ANiMAL - Rockshow/mixture.wav,train/ANiMAL - Rockshow/vocals.wav,train/ANiMAL - Rockshow/drums.wav,train/ANiMAL - Rockshow/bass.wav,train/ANiMAL - Rockshow/other.wav,165.511837 +train/Actions - One Minute Smile/mixture.wav,train/Actions - One Minute Smile/vocals.wav,train/Actions - One Minute Smile/drums.wav,train/Actions - One Minute Smile/bass.wav,train/Actions - One Minute Smile/other.wav,163.375601 +train/Alexander Ross - Goodbye Bolero/mixture.wav,train/Alexander Ross - Goodbye Bolero/vocals.wav,train/Alexander Ross - Goodbye Bolero/drums.wav,train/Alexander Ross - Goodbye Bolero/bass.wav,train/Alexander Ross - Goodbye Bolero/other.wav,418.632562 +train/Clara Berry And Wooldog - Waltz For My Victims/mixture.wav,train/Clara Berry And Wooldog - Waltz For My Victims/vocals.wav,train/Clara Berry And Wooldog - Waltz For My Victims/drums.wav,train/Clara Berry And Wooldog - Waltz For My Victims/bass.wav,train/Clara Berry And Wooldog - Waltz For My Victims/other.wav,175.240998 +train/Fergessen - Nos Palpitants/mixture.wav,train/Fergessen - Nos Palpitants/vocals.wav,train/Fergessen - Nos Palpitants/drums.wav,train/Fergessen - Nos Palpitants/bass.wav,train/Fergessen - Nos Palpitants/other.wav,198.228753 +train/James May - On The Line/mixture.wav,train/James May - On The Line/vocals.wav,train/James May - On The Line/drums.wav,train/James May - On The Line/bass.wav,train/James May - On The Line/other.wav,256.09288 +train/Johnny Lokke - Promises & Lies/mixture.wav,train/Johnny Lokke - Promises & Lies/vocals.wav,train/Johnny Lokke - Promises & Lies/drums.wav,train/Johnny Lokke - Promises & Lies/bass.wav,train/Johnny Lokke - Promises & Lies/other.wav,285.814422 +train/Leaf - Summerghost/mixture.wav,train/Leaf - Summerghost/vocals.wav,train/Leaf - Summerghost/drums.wav,train/Leaf - Summerghost/bass.wav,train/Leaf - Summerghost/other.wav,231.804807 +train/Meaxic - Take A Step/mixture.wav,train/Meaxic - Take A Step/vocals.wav,train/Meaxic - Take A Step/drums.wav,train/Meaxic - Take A Step/bass.wav,train/Meaxic - Take A Step/other.wav,282.517188 +train/Patrick Talbot - A Reason To Leave/mixture.wav,train/Patrick Talbot - A Reason To Leave/vocals.wav,train/Patrick Talbot - A Reason To Leave/drums.wav,train/Patrick Talbot - A Reason To Leave/bass.wav,train/Patrick Talbot - A Reason To Leave/other.wav,259.552653 +train/Skelpolu - Human Mistakes/mixture.wav,train/Skelpolu - Human Mistakes/vocals.wav,train/Skelpolu - Human Mistakes/drums.wav,train/Skelpolu - Human Mistakes/bass.wav,train/Skelpolu - Human Mistakes/other.wav,324.498866 +train/Traffic Experiment - Sirens/mixture.wav,train/Traffic Experiment - Sirens/vocals.wav,train/Traffic Experiment - Sirens/drums.wav,train/Traffic Experiment - Sirens/bass.wav,train/Traffic Experiment - Sirens/other.wav,421.279637 +train/Triviul - 
Angelsaint/mixture.wav,train/Triviul - Angelsaint/vocals.wav,train/Triviul - Angelsaint/drums.wav,train/Triviul - Angelsaint/bass.wav,train/Triviul - Angelsaint/other.wav,236.704218
+train/Young Griffo - Pennies/mixture.wav,train/Young Griffo - Pennies/vocals.wav,train/Young Griffo - Pennies/drums.wav,train/Young Griffo - Pennies/bass.wav,train/Young Griffo - Pennies/other.wav,277.803537
diff --git a/docker/cpu.Dockerfile b/docker/cpu.Dockerfile
new file mode 100644
index 00000000..e3f47e7f
--- /dev/null
+++ b/docker/cpu.Dockerfile
@@ -0,0 +1,24 @@
+FROM continuumio/miniconda3:4.7.10
+
+# install tensorflow
+RUN conda install -y tensorflow==1.14.0
+
+# install ffmpeg for audio loading/writing
+RUN conda install -y -c conda-forge ffmpeg
+
+# install extra python libraries
+RUN conda install -y -c anaconda pandas==0.25.1
+RUN conda install -y -c conda-forge libsndfile
+
+# install ipython
+RUN conda install -y ipython
+
+WORKDIR /workspace/
+COPY ./ spleeter/
+
+RUN mkdir /cache/
+
+WORKDIR /workspace/spleeter
+RUN pip install .
+
+ENTRYPOINT ["python", "-m", "spleeter"]
\ No newline at end of file
diff --git a/docker/gpu.Dockerfile b/docker/gpu.Dockerfile
new file mode 100644
index 00000000..aedeeda4
--- /dev/null
+++ b/docker/gpu.Dockerfile
@@ -0,0 +1,35 @@
+FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04
+
+# set work directory
+WORKDIR /workspace
+
+# install miniconda
+ENV PATH /opt/conda/bin:$PATH
+COPY docker/install_miniconda.sh .
+RUN bash ./install_miniconda.sh && rm install_miniconda.sh
+
+RUN conda update -n base -c defaults conda
+
+# install tensorflow for GPU
+RUN conda install -y tensorflow-gpu==1.14.0
+
+# install ffmpeg for audio loading/writing
+RUN conda install -y -c conda-forge ffmpeg
+
+# install extra libs
+RUN conda install -y -c anaconda pandas==0.25.1
+RUN conda install -y -c conda-forge libsndfile
+
+# install ipython
+RUN conda install -y ipython
+
+RUN mkdir /cache/
+
+# copy repository content inside the image
+COPY ./ spleeter/
+
+WORKDIR /workspace/spleeter
+RUN pip install .
+
+
+ENTRYPOINT ["python", "-m", "spleeter"]
\ No newline at end of file
diff --git a/docker/install_miniconda.sh b/docker/install_miniconda.sh
new file mode 100644
index 00000000..6ea58bcb
--- /dev/null
+++ b/docker/install_miniconda.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+apt-get update --fix-missing && \
+apt-get install -y wget bzip2 ca-certificates curl git && \
+apt-get clean && \
+rm -rf /var/lib/apt/lists/*
+
+wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-4.6.14-Linux-x86_64.sh -O ~/miniconda.sh && \
+/bin/bash ~/miniconda.sh -b -p /opt/conda && \
+rm ~/miniconda.sh && \
+/opt/conda/bin/conda clean -tipsy && \
+ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
+echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
+echo "conda activate base" >> ~/.bashrc
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..e350564f
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" Distribution script. """
+
+import sys
+
+from os import path
+from setuptools import setup
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+# Default project values.
+project_name = 'spleeter'
+project_version = '1.4.0'
+device_target = 'cpu'
+tensorflow_dependency = 'tensorflow'
+tensorflow_version = '1.14.0'
+here = path.abspath(path.dirname(__file__))
+readme_path = path.join(here, 'README.md')
+with open(readme_path, 'r') as stream:
+    readme = stream.read()
+
+# Check if GPU target is specified.
+if '--target' in sys.argv:
+    target_index = sys.argv.index('--target') + 1
+    device_target = sys.argv[target_index].lower()
+    # Pop the value first so the flag index remains valid.
+    sys.argv.pop(target_index)
+    sys.argv.remove('--target')
+
+# GPU target compatibility check.
+if device_target == 'gpu':
+    project_name = '{}-gpu'.format(project_name)
+    tensorflow_dependency = 'tensorflow-gpu'
+
+# Package setup entrypoint.
+setup(
+    name=project_name,
+    version=project_version,
+    description='''
+        The Deezer source separation library with
+        pretrained models based on tensorflow.
+    ''',
+    long_description=readme,
+    long_description_content_type='text/markdown',
+    author='Deezer Research',
+    author_email='research@deezer.com',
+    url='https://github.com/deezer/spleeter',
+    license='MIT License',
+    packages=[
+        'spleeter',
+        'spleeter.commands',
+        'spleeter.model',
+        'spleeter.model.functions',
+        'spleeter.model.provider',
+        'spleeter.resources',
+        'spleeter.utils',
+        'spleeter.utils.audio',
+    ],
+    package_data={'spleeter.resources': ['*.json']},
+    python_requires='>=3.6, <3.8',
+    include_package_data=True,
+    install_requires=[
+        'importlib_resources ; python_version<"3.7"',
+        'musdb==0.3.1',
+        'museval==0.3.0',
+        'norbert==0.2.1',
+        'pandas==0.25.1',
+        'requests',
+        '{}=={}'.format(tensorflow_dependency, tensorflow_version),
+    ],
+    entry_points={
+        'console_scripts': ['spleeter=spleeter.__main__:entrypoint']
+    },
+    classifiers=[
+        'Environment :: Console',
+        'Environment :: MacOS X',
+        'Intended Audience :: Developers',
+        'Intended Audience :: Information Technology',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: MIT License',
+        'Natural Language :: English',
+        'Operating System :: MacOS',
+        'Operating System :: Microsoft :: Windows',
+        'Operating System :: POSIX :: Linux',
+        'Operating System :: Unix',
+        'Programming Language :: Python',
+        'Programming Language :: Python :: 3',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3 :: Only',
+        'Programming Language :: Python :: Implementation :: CPython',
+        'Topic :: Artistic Software',
+        'Topic :: Multimedia',
+        'Topic :: Multimedia :: Sound/Audio',
+        'Topic :: Multimedia :: Sound/Audio :: Analysis',
+        'Topic :: Multimedia :: Sound/Audio :: Conversion',
+        'Topic :: Multimedia :: Sound/Audio :: Sound Synthesis',
+        'Topic :: Scientific/Engineering',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Topic :: Scientific/Engineering :: Information Analysis',
+        'Topic :: Software Development',
+        'Topic :: Software Development :: Libraries',
+        'Topic :: Software Development :: Libraries :: Python Modules',
+        'Topic :: Utilities']
+)
diff --git a/spleeter/__init__.py b/spleeter/__init__.py
new file mode 100644
index 00000000..e3693719
--- /dev/null
+++ b/spleeter/__init__.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+# coding: utf8
+
+"""
+    Spleeter is the Deezer source separation library with pretrained models.
+    The library is based on Tensorflow:
+
+    - It provides already trained models for performing separation.
+    - It makes it easy to train source separation models with tensorflow
+    (provided you have a dataset of isolated sources).
+
+    This module allows easy interaction with Spleeter from the command
+    line by providing train, evaluate and separate actions.
+"""
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
diff --git a/spleeter/__main__.py b/spleeter/__main__.py
new file mode 100644
index 00000000..fde52924
--- /dev/null
+++ b/spleeter/__main__.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# coding: utf8
+
+"""
+    Python one-liner script usage.
+
+    USAGE: python -m spleeter {train,evaluate,separate} ...
+"""
+
+import sys
+import warnings
+
+from .commands import create_argument_parser
+from .utils.configuration import load_configuration
+from .utils.logging import enable_logging, enable_verbose_logging
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+
+def main(argv):
+    """ Spleeter runner. Parse provided command line arguments
+    and run entrypoint for required command (either train,
+    evaluate or separate).
+
+    :param argv: Provided command line arguments.
+    """
+    parser = create_argument_parser()
+    arguments = parser.parse_args(argv[1:])
+    if arguments.verbose:
+        enable_verbose_logging()
+    else:
+        enable_logging()
+    if arguments.command == 'separate':
+        from .commands.separate import entrypoint
+    elif arguments.command == 'train':
+        from .commands.train import entrypoint
+    elif arguments.command == 'evaluate':
+        from .commands.evaluate import entrypoint
+    params = load_configuration(arguments.params_filename)
+    entrypoint(arguments, params)
+
+
+def entrypoint():
+    """ Command line entrypoint. """
+    warnings.filterwarnings('ignore')
+    main(sys.argv)
+
+
+if __name__ == '__main__':
+    entrypoint()
diff --git a/spleeter/commands/__init__.py b/spleeter/commands/__init__.py
new file mode 100644
index 00000000..25773998
--- /dev/null
+++ b/spleeter/commands/__init__.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" This module provides spleeter commands as well as CLI parsing methods. """
+
+from argparse import ArgumentParser
+from tempfile import gettempdir
+from os.path import join
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+# -i opt specification.
+OPT_INPUT = {
+    'dest': 'audio_filenames',
+    'nargs': '+',
+    'help': 'List of input audio filenames',
+    'required': True
+}
+
+# -o opt specification.
+OPT_OUTPUT = {
+    'dest': 'output_path',
+    'default': join(gettempdir(), 'separated_audio'),
+    'help': 'Path of the output directory to write audio files in'
+}
+
+# -p opt specification.
+OPT_PARAMS = {
+    'dest': 'params_filename',
+    'default': 'spleeter:2stems',
+    'type': str,
+    'action': 'store',
+    'help': 'JSON filename that contains params'
+}
+
+# -n opt specification.
+OPT_OUTPUT_NAMING = {
+    'dest': 'output_naming',
+    'default': 'filename',
+    'choices': ('directory', 'filename'),
+    'help': (
+        'Choice for naming the output base path: '
+        '"filename" (use the input filename, i.e. '
+        '/path/to/audio/mix.wav will be separated to '
+        '<output_path>/mix/<instrument1>.wav, '
+        '<output_path>/mix/<instrument2>.wav...) or '
+        '"directory" (use the name of the input last level'
+        ' directory, for instance /path/to/audio/mix.wav '
+        'will be separated to <output_path>/audio/<instrument1>.wav'
+        ', <output_path>/audio/<instrument2>.wav)')
+}
+
+# -d opt specification (separate).
+OPT_DURATION = {
+    'dest': 'max_duration',
+    'type': float,
+    'default': 600.,
+    'help': (
+        'Set a maximum duration for processing audio '
+        '(only separate max_duration first seconds of '
+        'the input file)')
+}
+
+# -c opt specification.
+OPT_CODEC = {
+    'dest': 'audio_codec',
+    'choices': ('wav', 'mp3', 'ogg', 'm4a', 'wma', 'flac'),
+    'default': 'wav',
+    'help': 'Audio codec to be used for the separated output'
+}
+
+# -m opt specification.
+OPT_MWF = {
+    'dest': 'MWF',
+    'action': 'store_const',
+    'const': True,
+    'default': False,
+    'help': 'Whether to use multichannel Wiener filtering for separation',
+}
+
+# --mus_dir opt specification.
+OPT_MUSDB = {
+    'dest': 'mus_dir',
+    'type': str,
+    'required': True,
+    'help': 'Path to folder with musDB'
+}
+
+# -d opt specification (train).
+OPT_DATA = {
+    'dest': 'audio_path',
+    'type': str,
+    'required': True,
+    'help': 'Path of the folder containing audio data for training'
+}
+
+# -a opt specification.
+OPT_ADAPTER = {
+    'dest': 'audio_adapter',
+    'type': str,
+    'help': 'Name of the audio adapter to use for audio I/O'
+}
+
+# --verbose opt specification.
+OPT_VERBOSE = {
+    'action': 'store_true',
+    'help': 'Shows verbose logs'
+}
+
+
+def _add_common_options(parser):
+    """ Add common options to the given parser.
+
+    :param parser: Parser to add common opt to.
+    """
+    parser.add_argument('-a', '--adapter', **OPT_ADAPTER)
+    parser.add_argument('-p', '--params_filename', **OPT_PARAMS)
+    parser.add_argument('--verbose', **OPT_VERBOSE)
+
+
+def _create_train_parser(parser_factory):
+    """ Creates an argparser for the training command.
+
+    :param parser_factory: Factory to use to create parser instance.
+    :returns: Created and configured parser.
+    """
+    parser = parser_factory('train', help='Train a source separation model')
+    _add_common_options(parser)
+    parser.add_argument('-d', '--data', **OPT_DATA)
+    return parser
+
+
+def _create_evaluate_parser(parser_factory):
+    """ Creates an argparser for the evaluation command.
+
+    :param parser_factory: Factory to use to create parser instance.
+    :returns: Created and configured parser.
+    """
+    parser = parser_factory(
+        'evaluate',
+        help='Evaluate a model on the musDB test dataset')
+    _add_common_options(parser)
+    parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
+    parser.add_argument('--mus_dir', **OPT_MUSDB)
+    parser.add_argument('-m', '--mwf', **OPT_MWF)
+    return parser
+
+
+def _create_separate_parser(parser_factory):
+    """ Creates an argparser for the separation command.
+
+    :param parser_factory: Factory to use to create parser instance.
+    :returns: Created and configured parser.
+    """
+    parser = parser_factory('separate', help='Separate audio files')
+    _add_common_options(parser)
+    parser.add_argument('-i', '--audio_filenames', **OPT_INPUT)
+    parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
+    parser.add_argument('-n', '--output_naming', **OPT_OUTPUT_NAMING)
+    parser.add_argument('-d', '--max_duration', **OPT_DURATION)
+    parser.add_argument('-c', '--audio_codec', **OPT_CODEC)
+    parser.add_argument('-m', '--mwf', **OPT_MWF)
+    return parser
+
+
+def create_argument_parser():
+    """ Creates overall command line parser for Spleeter.
+
+    :returns: Created argument parser.
+ """ + parser = ArgumentParser(prog='python -m spleeter') + subparsers = parser.add_subparsers() + subparsers.dest = 'command' + subparsers.required = True + _create_separate_parser(subparsers.add_parser) + _create_train_parser(subparsers.add_parser) + _create_evaluate_parser(subparsers.add_parser) + return parser diff --git a/spleeter/commands/evaluate.py b/spleeter/commands/evaluate.py new file mode 100644 index 00000000..dc990ada --- /dev/null +++ b/spleeter/commands/evaluate.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# coding: utf8 + +""" + Entrypoint provider for performing model evaluation. + + Evaluation is performed against musDB dataset. + + USAGE: python -m spleeter evaluate \ + -p /path/to/params \ + -o /path/to/output/dir \ + [-m] \ + --mus_dir /path/to/musdb dataset +""" + +import json + +from argparse import Namespace +from itertools import product +from glob import glob +from os.path import join, exists + +# pylint: disable=import-error +import musdb +import museval +import numpy as np +import pandas as pd +# pylint: enable=import-error + +from .separate import entrypoint as separate_entrypoint +from ..utils.logging import get_logger + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + +_SPLIT = 'test' +_MIXTURE = 'mixture.wav' +_NAMING = 'directory' +_AUDIO_DIRECTORY = 'audio' +_METRICS_DIRECTORY = 'metrics' +_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other') +_METRICS = ('SDR', 'SAR', 'SIR', 'ISR') + + +def _separate_evaluation_dataset(arguments, musdb_root_directory, params): + """ Performs audio separation on the musdb dataset from + the given directory and params. + + :param arguments: Entrypoint arguments. + :param musdb_root_directory: Directory to retrieve dataset from. + :param params: Spleeter configuration to apply to separation. + :returns: Separation output directory path. + """ + songs = glob(join(musdb_root_directory, _SPLIT, '*/')) + mixtures = [join(song, _MIXTURE) for song in songs] + audio_output_directory = join( + arguments.output_path, + _AUDIO_DIRECTORY) + separate_entrypoint( + Namespace( + audio_adapter=arguments.audio_adapter, + audio_filenames=mixtures, + audio_codec='wav', + output_path=join(audio_output_directory, _SPLIT), + output_naming=_NAMING, + max_duration=600., + MWF=arguments.MWF, + verbose=arguments.verbose), + params) + return audio_output_directory + + +def _compute_musdb_metrics( + arguments, + musdb_root_directory, + audio_output_directory): + """ Generates musdb metrics fro previsouly computed audio estimation. + + :param arguments: Entrypoint arguments. + :param audio_output_directory: Directory to get audio estimation from. + :returns: Path of generated metrics directory. + """ + metrics_output_directory = join( + arguments.output_path, + _METRICS_DIRECTORY) + get_logger().info('Starting musdb evaluation (this could be long) ...') + dataset = musdb.DB( + root=musdb_root_directory, + is_wav=True, + subsets=[_SPLIT]) + museval.eval_mus_dir( + dataset=dataset, + estimates_dir=audio_output_directory, + output_dir=metrics_output_directory) + get_logger().info('musdb evaluation done') + return metrics_output_directory + + +def _compile_metrics(metrics_output_directory): + """ Compiles metrics from given directory and returns + results as dict. + + :param metrics_output_directory: Directory to get metrics from. + :returns: Compiled metrics as dict. 
+ """ + songs = glob(join(metrics_output_directory, 'test/*.json')) + index = pd.MultiIndex.from_tuples( + product(_INSTRUMENTS, _METRICS), + names=['instrument', 'metric']) + pd.DataFrame([], index=['config1', 'config2'], columns=index) + metrics = { + instrument: {k: [] for k in _METRICS} + for instrument in _INSTRUMENTS} + for song in songs: + with open(song, 'r') as stream: + data = json.load(stream) + for target in data['targets']: + instrument = target['name'] + for metric in _METRICS: + sdr_med = np.median([ + frame['metrics'][metric] + for frame in target['frames'] + if not np.isnan(frame['metrics'][metric])]) + metrics[instrument][metric].append(sdr_med) + return metrics + + +def entrypoint(arguments, params): + """ Command entrypoint. + + :param arguments: Command line parsed argument as argparse.Namespace. + :param params: Deserialized JSON configuration file provided in CLI args. + """ + # Parse and check musdb directory. + musdb_root_directory = arguments.mus_dir + if not exists(musdb_root_directory): + raise IOError(f'musdb directory {musdb_root_directory} not found') + # Separate musdb sources. + audio_output_directory = _separate_evaluation_dataset( + arguments, + musdb_root_directory, + params) + # Compute metrics with musdb. + metrics_output_directory = _compute_musdb_metrics( + arguments, + musdb_root_directory, + audio_output_directory) + # Compute and pretty print median metrics. + metrics = _compile_metrics(metrics_output_directory) + for instrument, metric in metrics.items(): + get_logger().info('%s:', instrument) + for metric, value in metric.items(): + get_logger().info('%s: %s', metric, f'{np.median(value):.3f}') diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py new file mode 100644 index 00000000..0098351f --- /dev/null +++ b/spleeter/commands/separate.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# coding: utf8 + +""" + Entrypoint provider for performing source separation. + + USAGE: python -m spleeter separate \ + -p /path/to/params \ + -i inputfile1 inputfile2 ... inputfilen + -o /path/to/output/dir \ + -i /path/to/audio1.wav /path/to/audio2.mp3 +""" + +from multiprocessing import Pool +from os.path import isabs, join, split, splitext +from tempfile import gettempdir + +# pylint: disable=import-error +import tensorflow as tf +import numpy as np +# pylint: enable=import-error + +from ..utils.audio.adapter import get_audio_adapter +from ..utils.audio.convertor import to_n_channels +from ..utils.estimator import create_estimator +from ..utils.tensor import set_tensor_shape + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + + +def get_dataset(audio_adapter, filenames_and_crops, sample_rate, n_channels): + """" + Build a tensorflow dataset of waveform from a filename list wit crop + information. + + Params: + - audio_adapter: An AudioAdapter instance to load audio from. + - filenames_and_crops: list of (audio_filename, start, duration) + tuples separation is performed on each filaneme + from start (in seconds) to start + duration + (in seconds). + - sample_rate: audio sample_rate of the input and output audio + signals + - n_channels: int, number of channels of the input and output + audio signals + + Returns + A tensorflow dataset of waveform to feed a tensorflow estimator in + predict mode. 
+ """ + filenames, starts, ends = list(zip(*filenames_and_crops)) + dataset = tf.data.Dataset.from_tensor_slices({ + 'audio_id': list(filenames), + 'start': list(starts), + 'end': list(ends) + }) + # Load waveform. + dataset = dataset.map( + lambda sample: dict( + sample, + **audio_adapter.load_tf_waveform( + sample['audio_id'], + sample_rate=sample_rate, + offset=sample['start'], + duration=sample['end'] - sample['start'])), + num_parallel_calls=2) + # Filter out error. + dataset = dataset.filter( + lambda sample: tf.logical_not(sample['waveform_error'])) + # Convert waveform to the right number of channels. + dataset = dataset.map( + lambda sample: dict( + sample, + waveform=to_n_channels(sample['waveform'], n_channels))) + # Set number of channels (required for the model). + dataset = dataset.map( + lambda sample: dict( + sample, + waveform=set_tensor_shape(sample['waveform'], (None, n_channels)))) + return dataset + + +def process_audio( + audio_adapter, + filenames_and_crops, estimator, output_path, + sample_rate, n_channels, codec, output_naming): + """ + Perform separation on a list of audio ids. + + Params: + - audio_adapter: Audio adapter to use for audio I/O. + - filenames_and_crops: list of (audio_filename, start, duration) + tuples separation is performed on each filaneme + from start (in seconds) to start + duration + (in seconds). + - estimator: the tensorflow estimator that performs the + source separation. + - output_path: output_path where to export separated files. + - sample_rate: audio sample_rate of the input and output audio + signals + - n_channels: int, number of channels of the input and output + audio signals + - codec: string codec to be used for export (could be + "wav", "mp3", "ogg", "m4a") could be anything + supported by ffmpeg. + - output_naming: string (= "filename" of "directory") + naming convention for output. + for an input file /path/to/audio/input_file.wav: + * if output_naming is equal to "filename": + output files will be put in the directory /input_file + (/input_file/., + /input_file/....). + * if output_naming is equal to "directory": + output files will be put in the directory /audio/ + (/audio/., + /audio/....) + Use "directory" when separating the MusDB dataset. + + """ + # Get estimator + prediction = estimator.predict( + lambda: get_dataset( + audio_adapter, + filenames_and_crops, + sample_rate, + n_channels), + yield_single_examples=False) + # initialize pool for audio export + pool = Pool(16) + tasks = [] + for sample in prediction: + sample_filename = sample.pop('audio_id', 'unknown_filename').decode() + input_directory, input_filename = split(sample_filename) + if output_naming == 'directory': + output_dirname = split(input_directory)[1] + elif output_naming == 'filename': + output_dirname = splitext(input_filename)[0] + else: + raise ValueError(f'Unknown output naming {output_naming}') + for instrument, waveform in sample.items(): + filename = join( + output_path, + output_dirname, + f'{instrument}.{codec}') + tasks.append( + pool.apply_async( + audio_adapter.save, + (filename, waveform, sample_rate, codec))) + # Wait for everything to be written + for task in tasks: + task.wait(timeout=20) + + +def entrypoint(arguments, params): + """ Command entrypoint. + + :param arguments: Command line parsed argument as argparse.Namespace. + :param params: Deserialized JSON configuration file provided in CLI args. 
+ """ + audio_adapter = get_audio_adapter(arguments.audio_adapter) + filenames = arguments.audio_filenames + output_path = arguments.output_path + max_duration = arguments.max_duration + audio_codec = arguments.audio_codec + output_naming = arguments.output_naming + estimator = create_estimator(params, arguments.MWF) + filenames_and_crops = [ + (filename, 0., max_duration) + for filename in filenames] + process_audio( + audio_adapter, + filenames_and_crops, + estimator, + output_path, + params['sample_rate'], + params['n_channels'], + codec=audio_codec, + output_naming=output_naming) diff --git a/spleeter/commands/train.py b/spleeter/commands/train.py new file mode 100644 index 00000000..2814ae67 --- /dev/null +++ b/spleeter/commands/train.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# coding: utf8 + +""" + Entrypoint provider for performing model training. + + USAGE: python -m spleeter train -p /path/to/params +""" + +from functools import partial + +# pylint: disable=import-error +import tensorflow as tf +# pylint: enable=import-error + +from ..dataset import get_training_dataset, get_validation_dataset +from ..model import model_fn +from ..utils.audio.adapter import get_audio_adapter +from ..utils.logging import get_logger + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + + +def _create_estimator(params): + """ Creates estimator. + + :param params: TF params to build estimator from. + :returns: Built estimator. + """ + session_config = tf.compat.v1.ConfigProto() + session_config.gpu_options.per_process_gpu_memory_fraction = 0.45 + estimator = tf.estimator.Estimator( + model_fn=model_fn, + model_dir=params['model_dir'], + params=params, + config=tf.estimator.RunConfig( + save_checkpoints_steps=params['save_checkpoints_steps'], + tf_random_seed=params['random_seed'], + save_summary_steps=params['save_summary_steps'], + session_config=session_config, + log_step_count_steps=10, + keep_checkpoint_max=2)) + return estimator + + +def _create_train_spec(params, audio_adapter, audio_path): + """ Creates train spec. + + :param params: TF params to build spec from. + :returns: Built train spec. + """ + input_fn = partial(get_training_dataset, params, audio_adapter, audio_path) + train_spec = tf.estimator.TrainSpec( + input_fn=input_fn, + max_steps=params['train_max_steps']) + return train_spec + + +def _create_evaluation_spec(params, audio_adapter, audio_path): + """ Setup eval spec evaluating ever n seconds + + :param params: TF params to build spec from. + :returns: Built evaluation spec. + """ + input_fn = partial( + get_validation_dataset, + params, + audio_adapter, + audio_path) + evaluation_spec = tf.estimator.EvalSpec( + input_fn=input_fn, + steps=None, + throttle_secs=params['throttle_secs']) + return evaluation_spec + + +def entrypoint(arguments, params): + """ Command entrypoint. + + :param arguments: Command line parsed argument as argparse.Namespace. + :param params: Deserialized JSON configuration file provided in CLI args. 
+ """ + audio_adapter = get_audio_adapter(arguments.audio_adapter) + audio_path = arguments.audio_path + estimator = _create_estimator(params) + train_spec = _create_train_spec(params, audio_adapter, audio_path) + evaluation_spec = _create_evaluation_spec( + params, + audio_adapter, + audio_path) + get_logger().info('Start model training') + tf.estimator.train_and_evaluate( + estimator, + train_spec, + evaluation_spec) + get_logger().info('Model training done') diff --git a/spleeter/dataset.py b/spleeter/dataset.py new file mode 100644 index 00000000..dc656528 --- /dev/null +++ b/spleeter/dataset.py @@ -0,0 +1,464 @@ +#!/usr/bin/env python +# coding: utf8 + +""" + Module for building data preprocessing pipeline using the tensorflow data + API. + Data preprocessing such as audio loading, spectrogram computation, cropping, + feature caching or data augmentation is done using a tensorflow dataset object + that output a tuple (input_, output) where: + - input_ is a dictionary with a single key that contains the (batched) mix + spectrogram of audio samples + - output is a dictionary of spectrogram of the isolated tracks (ground truth) + +""" + +import time +import os +from os.path import exists, join, sep as SEPARATOR + +# pylint: disable=import-error +import pandas as pd +import numpy as np +import tensorflow as tf +# pylint: enable=import-error + +from .utils.audio.convertor import ( + db_uint_spectrogram_to_gain, + spectrogram_to_db_uint) +from .utils.audio.spectrogram import ( + compute_spectrogram_tf, + random_pitch_shift, + random_time_stretch) +from .utils.logging import get_logger +from .utils.tensor import ( + check_tensor_shape, + dataset_from_csv, + set_tensor_shape, + sync_apply) + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + +# Default datasets path parameter to use. +DEFAULT_DATASETS_PATH = join( + 'audio_database', + 'separated_sources', + 'experiments', + 'karaoke_vocal_extraction', + 'tensorflow_experiment' +) + +# Default audio parameters to use. +DEFAULT_AUDIO_PARAMS = { + 'instrument_list': ('vocals', 'accompaniment'), + 'mix_name': 'mix', + 'sample_rate': 44100, + 'frame_length': 4096, + 'frame_step': 1024, + 'T': 512, + 'F': 1024 +} + + +def get_training_dataset(audio_params, audio_adapter, audio_path): + """ Builds training dataset. + + :param audio_params: Audio parameters. + :param audio_adapter: Adapter to load audio from. + :param audio_path: Path of directory containing audio. + :returns: Built dataset. + """ + builder = DatasetBuilder( + audio_params, + audio_adapter, + audio_path, + chunk_duration=audio_params.get('chunk_duration', 20.0), + random_seed=audio_params.get('random_seed', 0)) + return builder.build( + audio_params.get('train_csv'), + cache_directory=audio_params.get('training_cache'), + batch_size=audio_params.get('batch_size'), + n_chunks_per_song=audio_params.get('n_chunks_per_song', 2), + random_data_augmentation=False, + convert_to_uint=True, + wait_for_cache=False) + + +def get_validation_dataset(audio_params, audio_adapter, audio_path): + """ Builds validation dataset. + + :param audio_params: Audio parameters. + :param audio_adapter: Adapter to load audio from. + :param audio_path: Path of directory containing audio. + :returns: Built dataset. 
+ """ + builder = DatasetBuilder( + audio_params, + audio_adapter, + audio_path, + chunk_duration=12.0) + return builder.build( + audio_params.get('validation_csv'), + batch_size=audio_params.get('batch_size'), + cache_directory=audio_params.get('training_cache'), + convert_to_uint=True, + infinite_generator=False, + n_chunks_per_song=1, + # should not perform data augmentation for eval: + random_data_augmentation=False, + random_time_crop=False, + shuffle=False, + ) + + +class InstrumentDatasetBuilder(object): + """ Instrument based filter and mapper provider. """ + + def __init__(self, parent, instrument): + """ Default constructor. + + :param parent: Parent dataset builder. + :param instrument: Target instrument. + """ + self._parent = parent + self._instrument = instrument + self._spectrogram_key = f'{instrument}_spectrogram' + self._min_spectrogram_key = f'min_{instrument}_spectrogram' + self._max_spectrogram_key = f'max_{instrument}_spectrogram' + + def load_waveform(self, sample): + """ Load waveform for given sample. """ + return dict(sample, **self._parent._audio_adapter.load_tf_waveform( + sample[f'{self._instrument}_path'], + offset=sample['start'], + duration=self._parent._chunk_duration, + sample_rate=self._parent._sample_rate, + waveform_name='waveform')) + + def compute_spectrogram(self, sample): + """ Compute spectrogram of the given sample. """ + return dict(sample, **{ + self._spectrogram_key: compute_spectrogram_tf( + sample['waveform'], + frame_length=self._parent._frame_length, + frame_step=self._parent._frame_step, + spec_exponent=1., + window_exponent=1.)}) + + def filter_frequencies(self, sample): + """ """ + return dict(sample, **{ + self._spectrogram_key: + sample[self._spectrogram_key][:, :self._parent._F, :]}) + + def convert_to_uint(self, sample): + """ Convert given sample from float to unit. """ + return dict(sample, **spectrogram_to_db_uint( + sample[self._spectrogram_key], + tensor_key=self._spectrogram_key, + min_key=self._min_spectrogram_key, + max_key=self._max_spectrogram_key)) + + def filter_infinity(self, sample): + """ Filter infinity sample. """ + return tf.logical_not( + tf.math.is_inf( + sample[self._min_spectrogram_key])) + + def convert_to_float32(self, sample): + """ Convert given sample from unit to float. """ + return dict(sample, **{ + self._spectrogram_key: db_uint_spectrogram_to_gain( + sample[self._spectrogram_key], + sample[self._min_spectrogram_key], + sample[self._max_spectrogram_key])}) + + def time_crop(self, sample): + """ """ + def start(sample): + """ mid_segment_start """ + return tf.cast( + tf.maximum( + tf.shape(sample[self._spectrogram_key])[0] + / 2 - self._parent._T / 2, 0), + tf.int32) + return dict(sample, **{ + self._spectrogram_key: sample[self._spectrogram_key][ + start(sample):start(sample) + self._parent._T, :, :]}) + + def filter_shape(self, sample): + """ Filter badly shaped sample. """ + return check_tensor_shape( + sample[self._spectrogram_key], ( + self._parent._T, self._parent._F, 2)) + + def reshape_spectrogram(self, sample): + """ """ + return dict(sample, **{ + self._spectrogram_key: set_tensor_shape( + sample[self._spectrogram_key], + (self._parent._T, self._parent._F, 2))}) + + +class DatasetBuilder(object): + """ + """ + + # Margin at beginning and end of songs in seconds. + MARGIN = 0.5 + + # Wait period for cache (in seconds). + WAIT_PERIOD = 60 + + def __init__( + self, + audio_params, audio_adapter, audio_path, + random_seed=0, chunk_duration=20.0): + """ Default constructor. 
+ + NOTE: Probably need for AudioAdapter. + + :param audio_params: Audio parameters to use. + :param audio_adapter: Audio adapter to use. + :param audio_path: + :param random_seed: + :param chunk_duration: + """ + # Length of segment in frames (if fs=22050 and + # frame_step=512, then T=512 corresponds to 11.89s) + self._T = audio_params['T'] + # Number of frequency bins to be used (should + # be less than frame_length/2 + 1) + self._F = audio_params['F'] + self._sample_rate = audio_params['sample_rate'] + self._frame_length = audio_params['frame_length'] + self._frame_step = audio_params['frame_step'] + self._mix_name = audio_params['mix_name'] + self._instruments = [self._mix_name] + audio_params['instrument_list'] + self._instrument_builders = None + self._chunk_duration = chunk_duration + self._audio_adapter = audio_adapter + self._audio_params = audio_params + self._audio_path = audio_path + self._random_seed = random_seed + + def expand_path(self, sample): + """ Expands audio paths for the given sample. """ + return dict(sample, **{f'{instrument}_path': tf.string_join( + (self._audio_path, sample[f'{instrument}_path']), SEPARATOR) + for instrument in self._instruments}) + + def filter_error(self, sample): + """ Filter errored sample. """ + return tf.logical_not(sample['waveform_error']) + + def filter_waveform(self, sample): + """ Filter waveform from sample. """ + return {k: v for k, v in sample.items() if not k == 'waveform'} + + def harmonize_spectrogram(self, sample): + """ Ensure same size for vocals and mix spectrograms. """ + def _reduce(sample): + return tf.reduce_min([ + tf.shape(sample[f'{instrument}_spectrogram'])[0] + for instrument in self._instruments]) + return dict(sample, **{ + f'{instrument}_spectrogram': + sample[f'{instrument}_spectrogram'][:_reduce(sample), :, :] + for instrument in self._instruments}) + + def filter_short_segments(self, sample): + """ Filter out too short segment. """ + return tf.reduce_any([ + tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T + for instrument in self._instruments]) + + def random_time_crop(self, sample): + """ Random time crop of 11.88s. """ + return dict(sample, **sync_apply({ + f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram'] + for instrument in self._instruments}, + lambda x: tf.image.random_crop( + x, (self._T, len(self._instruments) * self._F, 2), + seed=self._random_seed))) + + def random_time_stretch(self, sample): + """ Randomly time stretch the given sample. """ + return dict(sample, **sync_apply({ + f'{instrument}_spectrogram': + sample[f'{instrument}_spectrogram'] + for instrument in self._instruments}, + lambda x: random_time_stretch( + x, factor_min=0.9, factor_max=1.1))) + + def random_pitch_shift(self, sample): + """ Randomly pitch shift the given sample. """ + return dict(sample, **sync_apply({ + f'{instrument}_spectrogram': + sample[f'{instrument}_spectrogram'] + for instrument in self._instruments}, + lambda x: random_pitch_shift( + x, shift_min=-1.0, shift_max=1.0), concat_axis=0)) + + def map_features(self, sample): + """ Select features and annotation of the given sample. """ + input_ = { + f'{self._mix_name}_spectrogram': + sample[f'{self._mix_name}_spectrogram']} + output = { + f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram'] + for instrument in self._audio_params['instrument_list']} + return (input_, output) + + def compute_segments(self, dataset, n_chunks_per_song): + """ Computes segments for each song of the dataset. 
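A quick check of the frame arithmetic in the comments above: a chunk of T frames covers T * frame_step / sample_rate seconds, which is where the 11.88s/11.89s figure used elsewhere in this module comes from:

```python
# Default parameters of this patch: T=512 frames, frame_step=1024,
# sample_rate=44100.
T, frame_step, sample_rate = 512, 1024, 44100
print(T * frame_step / sample_rate)  # ~11.89 seconds
```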
+ + :param dataset: Dataset to compute segments for. + :param n_chunks_per_song: Number of segment per song to compute. + :returns: Segmented dataset. + """ + if n_chunks_per_song <= 0: + raise ValueError('n_chunks_per_song must be positif') + datasets = [] + for k in range(n_chunks_per_song): + if n_chunks_per_song > 1: + datasets.append( + dataset.map(lambda sample: dict(sample, start=tf.maximum( + k * ( + sample['duration'] - self._chunk_duration - 2 + * self.MARGIN) / (n_chunks_per_song - 1) + + self.MARGIN, 0)))) + elif n_chunks_per_song == 1: # Take central segment. + datasets.append( + dataset.map(lambda sample: dict(sample, start=tf.maximum( + sample['duration'] / 2 - self._chunk_duration / 2, + 0)))) + dataset = datasets[-1] + for d in datasets[:-1]: + dataset = dataset.concatenate(d) + return dataset + + @property + def instruments(self): + """ Instrument dataset builder generator. + + :yield InstrumentBuilder instance. + """ + if self._instrument_builders is None: + self._instrument_builders = [] + for instrument in self._instruments: + self._instrument_builders.append( + InstrumentDatasetBuilder(self, instrument)) + for builder in self._instrument_builders: + yield builder + + def cache(self, dataset, cache, wait): + """ Cache the given dataset if cache is enabled. Eventually waits for + cache to be available (useful if another process is already computing + cache) if provided wait flag is True. + + :param dataset: Dataset to be cached if cache is required. + :param cache: Path of cache directory to be used, None if no cache. + :param wait: If caching is enabled, True is cache should be waited. + :returns: Cached dataset if needed, original dataset otherwise. + """ + if cache is not None: + if wait: + while not exists(f'{cache}.index'): + get_logger().info( + 'Cache not available, wait %s', + self.WAIT_PERIOD) + time.sleep(self.WAIT_PERIOD) + cache_path = os.path.split(cache)[0] + os.makedirs(cache_path, exist_ok=True) + return dataset.cache(cache) + return dataset + + def build( + self, csv_path, + batch_size=8, shuffle=True, convert_to_uint=True, + random_data_augmentation=False, random_time_crop=True, + infinite_generator=True, cache_directory=None, + wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,): + """ + TO BE DOCUMENTED. + """ + dataset = dataset_from_csv(csv_path) + dataset = self.compute_segments(dataset, n_chunks_per_song) + # Shuffle data + if shuffle: + dataset = dataset.shuffle( + buffer_size=200000, + seed=self._random_seed, + # useless since it is cached : + reshuffle_each_iteration=True) + # Expand audio path. + dataset = dataset.map(self.expand_path) + # Load waveform, compute spectrogram, and filtering error, + # K bins frequencies, and waveform. + N = num_parallel_calls + for instrument in self.instruments: + dataset = ( + dataset + .map(instrument.load_waveform, num_parallel_calls=N) + .filter(self.filter_error) + .map(instrument.compute_spectrogram, num_parallel_calls=N) + .map(instrument.filter_frequencies)) + dataset = dataset.map(self.filter_waveform) + # Convert to uint before caching in order to save space. 
+ if convert_to_uint: + for instrument in self.instruments: + dataset = dataset.map(instrument.convert_to_uint) + dataset = self.cache(dataset, cache_directory, wait_for_cache) + # Check for INFINITY (should not happen) + for instrument in self.instruments: + dataset = dataset.filter(instrument.filter_infinity) + # Repeat indefinitly + if infinite_generator: + dataset = dataset.repeat(count=-1) + # Ensure same size for vocals and mix spectrograms. + # NOTE: could be done before caching ? + dataset = dataset.map(self.harmonize_spectrogram) + # Filter out too short segment. + # NOTE: could be done before caching ? + dataset = dataset.filter(self.filter_short_segments) + # Random time crop of 11.88s + if random_time_crop: + dataset = dataset.map(self.random_time_crop, num_parallel_calls=N) + else: + # frame_duration = 11.88/T + # take central segment (for validation) + for instrument in self.instruments: + dataset = dataset.map(instrument.time_crop) + # Post cache shuffling. Done where the data are the lightest: + # after croping but before converting back to float. + if shuffle: + dataset = dataset.shuffle( + buffer_size=256, seed=self._random_seed, + reshuffle_each_iteration=True) + # Convert back to float32 + if convert_to_uint: + for instrument in self.instruments: + dataset = dataset.map( + instrument.convert_to_float32, num_parallel_calls=N) + M = 8 # Parallel call post caching. + # Must be applied with the same factor on mix and vocals. + if random_data_augmentation: + dataset = ( + dataset + .map(self.random_time_stretch, num_parallel_calls=M) + .map(self.random_pitch_shift, num_parallel_calls=M)) + # Filter by shape (remove badly shaped tensors). + for instrument in self.instruments: + dataset = ( + dataset + .filter(instrument.filter_shape) + .map(instrument.reshape_spectrogram)) + # Select features and annotation. + dataset = dataset.map(self.map_features) + # Make batch (done after selection to avoid + # error due to unprocessed instrument spectrogram batching). + dataset = dataset.batch(batch_size) + return dataset diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py new file mode 100644 index 00000000..384e8389 --- /dev/null +++ b/spleeter/model/__init__.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python +# coding: utf8 + +""" This package provide an estimator builder as well as model functions. """ + +import importlib + +# pylint: disable=import-error +import tensorflow as tf + +from tensorflow.contrib.signal import stft, inverse_stft, hann_window +# pylint: enable=import-error + +from ..utils.tensor import pad_and_partition, pad_and_reshape + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + + +def get_model_function(model_type): + """ + Get tensorflow function of the model to be applied to the input tensor. + For instance "unet.softmax_unet" will return the softmax_unet function + in the "unet.py" submodule of the current module (spleeter.model). + + Params: + - model_type: str + the relative module path to the model function. + + Returns: + A tensorflow function to be applied to the input tensor to get the + multitrack output. 
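
This helper is the hook through which the `model.type` field of the configuration JSONs selects an architecture; a usage sketch, assuming the package is importable:

```python
from spleeter.model import get_model_function

# 'unet.unet' is the default; 'unet.softmax_unet' and 'blstm.blstm' are the
# other model functions shipped in this patch.
apply_model = get_model_function('unet.softmax_unet')
print(apply_model.__name__)  # softmax_unet
```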
+ """ + relative_path_to_module = '.'.join(model_type.split('.')[:-1]) + model_name = model_type.split('.')[-1] + main_module = '.'.join((__name__, 'functions')) + path_to_module = f'{main_module}.{relative_path_to_module}' + module = importlib.import_module(path_to_module) + model_function = getattr(module, model_name) + return model_function + + +class EstimatorSpecBuilder(object): + """ A builder class that allows to builds a multitrack unet model + estimator. The built model estimator has a different behaviour when + used in a train/eval mode and in predict mode. + + * In train/eval mode: it takes as input and outputs magnitude spectrogram + * In predict mode: it takes as input and outputs waveform. The whole + separation process is then done in this function + for performance reason: it makes it possible to run + the whole spearation process (including STFT and + inverse STFT) on GPU. + + :Example: + + >>> from spleeter.model import EstimatorSpecBuilder + >>> builder = EstimatorSpecBuilder() + >>> builder.build_prediction_model() + >>> builder.build_evaluation_model() + >>> builder.build_training_model() + + >>> from spleeter.model import model_fn + >>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...) + """ + + # Supported model functions. + DEFAULT_MODEL = 'unet.unet' + + # Supported loss functions. + L1_MASK = 'L1_mask' + WEIGHTED_L1_MASK = 'weighted_L1_mask' + + # Supported optimizers. + ADADELTA = 'Adadelta' + SGD = 'SGD' + + # Math constants. + WINDOW_COMPENSATION_FACTOR = 2./3. + EPSILON = 1e-10 + + def __init__(self, features, params): + """ Default constructor. Depending on built model + usage, the provided features should be different: + + * In train/eval mode: features is a dictionary with a + "mix_spectrogram" key, associated to the + mix magnitude spectrogram. + * In predict mode: features is a dictionary with a "waveform" + key, associated to the waveform of the sound + to be separated. + + :param features: The input features for the estimator. + :param params: Some hyperparameters as a dictionary. + """ + self._features = features + self._params = params + # Get instrument name. + self._mix_name = params['mix_name'] + self._instruments = params['instrument_list'] + # Get STFT/signals parameters + self._n_channels = params['n_channels'] + self._T = params['T'] + self._F = params['F'] + self._frame_length = params['frame_length'] + self._frame_step = params['frame_step'] + + def _build_output_dict(self): + """ Created a batch_sizexTxFxn_channels input tensor containing + mix magnitude spectrogram, then an output dict from it according + to the selected model in internal parameters. + + :returns: Build output dict. + :raise ValueError: If required model_type is not supported. 
+ """ + input_tensor = self._features[f'{self._mix_name}_spectrogram'] + model = self._params.get('model', None) + if model is not None: + model_type = model.get('type', self.DEFAULT_MODEL) + else: + model_type = self.DEFAULT_MODEL + try: + apply_model = get_model_function(model_type) + except ModuleNotFoundError: + raise ValueError(f'No model function {model_type} found') + return apply_model( + input_tensor, + self._instruments, + self._params['model']['params']) + + def _build_loss(self, output_dict, labels): + """ Construct tensorflow loss and metrics + + :param output_dict: dictionary of network outputs (key: instrument + name, value: estimated spectrogram of the instrument) + :param labels: dictionary of target outputs (key: instrument + name, value: ground truth spectrogram of the instrument) + :returns: tensorflow (loss, metrics) tuple. + """ + loss_type = self._params.get('loss_type', self.L1_MASK) + if loss_type == self.L1_MASK: + losses = { + name: tf.reduce_mean(tf.abs(output - labels[name])) + for name, output in output_dict.items() + } + elif loss_type == self.WEIGHTED_L1_MASK: + losses = { + name: tf.reduce_mean( + tf.reduce_mean( + labels[name], + axis=[1, 2, 3], + keep_dims=True) * + tf.abs(output - labels[name])) + for name, output in output_dict.items() + } + else: + raise ValueError(f"Unkwnown loss type: {loss_type}") + loss = tf.reduce_sum(list(losses.values())) + # Add metrics for monitoring each instrument. + metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()} + metrics['absolute_difference'] = tf.compat.v1.metrics.mean(loss) + return loss, metrics + + def _build_optimizer(self): + """ Builds an optimizer instance from internal parameter values. + + Default to AdamOptimizer if not specified. + + :returns: Optimizer instance from internal configuration. + """ + name = self._params.get('optimizer') + if name == self.ADADELTA: + return tf.compat.v1.train.AdadeltaOptimizer() + rate = self._params['learning_rate'] + if name == self.SGD: + return tf.compat.v1.train.GradientDescentOptimizer(rate) + return tf.compat.v1.train.AdamOptimizer(rate) + + def _build_stft_feature(self): + """ Compute STFT of waveform and slice the STFT in segment + with the right length to feed the network. + """ + stft_feature = tf.transpose( + stft( + tf.transpose(self._features['waveform']), + self._frame_length, + self._frame_step, + window_fn=lambda frame_length, dtype: ( + hann_window(frame_length, periodic=True, dtype=dtype)), + pad_end=True), + perm=[1, 2, 0]) + self._features[f'{self._mix_name}_stft'] = stft_feature + self._features[f'{self._mix_name}_spectrogram'] = tf.abs( + pad_and_partition(stft_feature, self._T))[:, :, :self._F, :] + + def _inverse_stft(self, stft): + """ Inverse and reshape the given STFT + + :param stft: input STFT + :returns: inverse STFT (waveform) + """ + inversed = inverse_stft( + tf.transpose(stft, perm=[2, 0, 1]), + self._frame_length, + self._frame_step, + window_fn=lambda frame_length, dtype: ( + hann_window(frame_length, periodic=True, dtype=dtype)) + ) * self.WINDOW_COMPENSATION_FACTOR + reshaped = tf.transpose(inversed) + return reshaped[:tf.shape(self._features['waveform'])[0], :] + + def _build_mwf_output_waveform(self, output_dict): + """ Perform separation with multichannel Wiener Filtering using Norbert. + Note: multichannel Wiener Filtering is not coded in Tensorflow and thus + may be quite slow. 
+ + :param output_dict: dictionary of estimated spectrogram (key: instrument + name, value: estimated spectrogram of the instrument) + :returns: dictionary of separated waveforms (key: instrument name, + value: estimated waveform of the instrument) + """ + import norbert # pylint: disable=import-error + x = self._features[f'{self._mix_name}_stft'] + v = tf.stack( + [ + pad_and_reshape( + output_dict[f'{instrument}_spectrogram'], + self._frame_length, + self._F)[:tf.shape(x)[0], ...] + for instrument in self._instruments + ], + axis=3) + input_args = [v, x] + stft_function = tf.py_function( + lambda v, x: norbert.wiener(v.numpy(), x.numpy()), + input_args, + tf.complex64), + return { + instrument: self._inverse_stft(stft_function[0][:, :, :, k]) + for k, instrument in enumerate(self._instruments) + } + + def _extend_mask(self, mask): + """ Extend mask, from reduced number of frequency bin to the number of + frequency bin in the STFT. + + :param mask: restricted mask + :returns: extended mask + :raise ValueError: If invalid mask_extension parameter is set. + """ + extension = self._params['mask_extension'] + # Extend with average + # (dispatch according to energy in the processed band) + if extension == "average": + extension_row = tf.reduce_mean(mask, axis=2, keepdims=True) + # Extend with 0 + # (avoid extension artifacts but not conservative separation) + elif extension == "zeros": + mask_shape = tf.shape(mask) + extension_row = tf.zeros(( + mask_shape[0], + mask_shape[1], + 1, + mask_shape[-1])) + else: + raise ValueError(f'Invalid mask_extension parameter {extension}') + n_extra_row = (self._frame_length) // 2 + 1 - self._F + extension = tf.tile(extension_row, [1, 1, n_extra_row, 1]) + return tf.concat([mask, extension], axis=2) + + def _build_manual_output_waveform(self, output_dict): + """ Perform ratio mask separation + + :param output_dict: dictionary of estimated spectrogram (key: instrument + name, value: estimated spectrogram of the instrument) + :returns: dictionary of separated waveforms (key: instrument name, + value: estimated waveform of the instrument) + """ + separation_exponent = self._params['separation_exponent'] + output_sum = tf.reduce_sum( + [e ** separation_exponent for e in output_dict.values()], + axis=0 + ) + self.EPSILON + output_waveform = {} + for instrument in self._instruments: + output = output_dict[f'{instrument}_spectrogram'] + # Compute mask with the model. + instrument_mask = ( + output ** separation_exponent + + (self.EPSILON / len(output_dict))) / output_sum + # Extend mask; + instrument_mask = self._extend_mask(instrument_mask) + # Stack back mask. + old_shape = tf.shape(instrument_mask) + new_shape = tf.concat( + [[old_shape[0] * old_shape[1]], old_shape[2:]], + axis=0) + instrument_mask = tf.reshape(instrument_mask, new_shape) + # Remove padded part (for mask having the same size as STFT); + stft_feature = self._features[f'{self._mix_name}_stft'] + instrument_mask = instrument_mask[ + :tf.shape(stft_feature)[0], ...] + # Compute masked STFT and normalize it. + output_waveform[instrument] = self._inverse_stft( + tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature) + return output_waveform + + def _build_output_waveform(self, output_dict): + """ Build output waveform from given output dict in order to be used in + prediction context. Regarding of the configuration building method will + be using MWF. + + :param output_dict: Output dict to build output waveform from. + :returns: Built output waveform. 
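
The manual (non-MWF) path described above computes ratio masks with a separation exponent p and a small epsilon; by construction the masks sum to one across instruments, as this numpy sketch verifies:

```python
import numpy as np

p, epsilon = 2, 1e-10  # defaults from the bundled configurations
magnitudes = {
    name: np.random.rand(128, 256)
    for name in ('vocals', 'accompaniment')}
total = sum(m ** p for m in magnitudes.values()) + epsilon
masks = {
    name: (m ** p + epsilon / len(magnitudes)) / total
    for name, m in magnitudes.items()}
assert np.allclose(sum(masks.values()), 1.0)
```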
+ """ + if self._params.get('MWF', False): + output_waveform = self._build_mwf_output_waveform(output_dict) + else: + output_waveform = self._build_manual_output_waveform(output_dict) + if 'audio_id' in self._features: + output_waveform['audio_id'] = self._features['audio_id'] + return output_waveform + + def build_predict_model(self): + """ Builder interface for creating model instance that aims to perform + prediction / inference over given track. The output of such estimator + will be a dictionary with a "" key per separated instrument + , associated to the estimated separated waveform of the instrument. + + :returns: An estimator for performing prediction. + """ + self._build_stft_feature() + output_dict = self._build_output_dict() + output_waveform = self._build_output_waveform(output_dict) + return tf.estimator.EstimatorSpec( + tf.estimator.ModeKeys.PREDICT, + predictions=output_waveform) + + def build_evaluation_model(self, labels): + """ Builder interface for creating model instance that aims to perform + model evaluation. The output of such estimator will be a dictionary + with a key "_spectrogram" per separated instrument, + associated to the estimated separated instrument magnitude spectrogram. + + :param labels: Model labels. + :returns: An estimator for performing model evaluation. + """ + output_dict = self._build_output_dict() + loss, metrics = self._build_loss(output_dict, labels) + return tf.estimator.EstimatorSpec( + tf.estimator.ModeKeys.EVAL, + loss=loss, + eval_metric_ops=metrics) + + def build_train_model(self, labels): + """ Builder interface for creating model instance that aims to perform + model training. The output of such estimator will be a dictionary + with a key "_spectrogram" per separated instrument, + associated to the estimated separated instrument magnitude spectrogram. + + :param labels: Model labels. + :returns: An estimator for performing model training. + """ + output_dict = self._build_output_dict() + loss, metrics = self._build_loss(output_dict, labels) + optimizer = self._build_optimizer() + train_operation = optimizer.minimize( + loss=loss, + global_step=tf.compat.v1.train.get_global_step()) + return tf.estimator.EstimatorSpec( + mode=tf.estimator.ModeKeys.TRAIN, + loss=loss, + train_op=train_operation, + eval_metric_ops=metrics, + ) + + +def model_fn(features, labels, mode, params, config): + """ + + :param features: + :param labels: + :param mode: Estimator mode. + :param params: + :param config: TF configuration (not used). + :returns: Built EstimatorSpec. + :raise ValueError: If estimator mode is not supported. + """ + builder = EstimatorSpecBuilder(features, params) + if mode == tf.estimator.ModeKeys.PREDICT: + return builder.build_predict_model() + elif mode == tf.estimator.ModeKeys.EVAL: + return builder.build_evaluation_model(labels) + elif mode == tf.estimator.ModeKeys.TRAIN: + return builder.build_train_model(labels) + raise ValueError(f'Unknown mode {mode}') diff --git a/spleeter/model/functions/__init__.py b/spleeter/model/functions/__init__.py new file mode 100644 index 00000000..abe52e9a --- /dev/null +++ b/spleeter/model/functions/__init__.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# coding: utf8 + +""" This package provide model functions. """ + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + + +def apply(function, input_tensor, instruments, params={}): + """ Apply given function to the input tensor. + + :param function: Function to be applied to tensor. 
+    :param input_tensor: Tensor to apply the model function to.
+    :param instruments: Iterable that provides a collection of instruments.
+    :param params: (Optional) dict of model parameters.
+    :returns: Created output tensor dict.
+    """
+    output_dict = {}
+    for instrument in instruments:
+        out_name = f'{instrument}_spectrogram'
+        output_dict[out_name] = function(
+            input_tensor,
+            output_name=out_name,
+            params=params)
+    return output_dict
diff --git a/spleeter/model/functions/blstm.py b/spleeter/model/functions/blstm.py
new file mode 100644
index 00000000..ff7ce020
--- /dev/null
+++ b/spleeter/model/functions/blstm.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python
+# coding: utf8
+
+"""
+    This system (UHL1) uses a bi-directional LSTM network as described in:
+
+    `S. Uhlich, M. Porcu, F. Giron, M. Enenkl, T. Kemp, N. Takahashi and
+    Y. Mitsufuji.
+
+    "Improving music source separation based on deep neural networks through
+    data augmentation and network blending", Proc. ICASSP, 2017.`
+
+    It has three BLSTM layers, each having 500 cells. For each instrument,
+    a network is trained which predicts the target instrument amplitude from
+    the mixture amplitude in the STFT domain (frame size: 4096, hop size:
+    1024). The raw output of each network is then combined by a multichannel
+    Wiener filter. The network is trained on musdb where we split train into
+    train_train and train_valid with 86 and 14 songs, respectively. The
+    validation set is used to perform early stopping and hyperparameter
+    selection (LSTM layer dropout rate, regularization strength).
+"""
+
+# pylint: disable=import-error
+from tensorflow.compat.v1.keras.initializers import he_uniform
+from tensorflow.compat.v1.keras.layers import CuDNNLSTM
+from tensorflow.keras.layers import (
+    Bidirectional,
+    Dense,
+    Flatten,
+    Reshape,
+    TimeDistributed)
+# pylint: enable=import-error
+
+from . import apply
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+
+def apply_blstm(input_tensor, output_name='output', params={}):
+    """ Apply BLSTM to the given input_tensor.
+
+    :param input_tensor: Input of the model.
+    :param output_name: (Optional) name of the output, default to 'output'.
+    :param params: (Optional) dict of BLSTM parameters.
+    :returns: Output tensor.
+    """
+    units = params.get('lstm_units', 250)
+    kernel_initializer = he_uniform(seed=50)
+    flatten_input = TimeDistributed(Flatten())(input_tensor)
+
+    def create_bidirectional():
+        return Bidirectional(
+            CuDNNLSTM(
+                units,
+                kernel_initializer=kernel_initializer,
+                return_sequences=True))
+
+    l1 = create_bidirectional()(flatten_input)
+    l2 = create_bidirectional()(l1)
+    l3 = create_bidirectional()(l2)
+    dense = TimeDistributed(
+        Dense(
+            int(flatten_input.shape[2]),
+            activation='relu',
+            kernel_initializer=kernel_initializer))(l3)
+    output = TimeDistributed(
+        Reshape(input_tensor.shape[2:]),
+        name=output_name)(dense)
+    return output
+
+
+def blstm(input_tensor, instruments, params={}):
+    """ Model function applier. """
+    return apply(apply_blstm, input_tensor, instruments, params)
diff --git a/spleeter/model/functions/unet.py b/spleeter/model/functions/unet.py
new file mode 100644
index 00000000..245a5e58
--- /dev/null
+++ b/spleeter/model/functions/unet.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python
+# coding: utf8
+
+"""
+This module contains building functions for U-net source
+separation models.
+Each instrument is modeled by a single U-net convolutional/deconvolutional +network that take a mix spectrogram as input and the estimated sound spectrogram +as output. +""" + +from functools import partial + +# pylint: disable=import-error +import tensorflow as tf + +from tensorflow.keras.layers import ( + BatchNormalization, + Concatenate, + Conv2D, + Conv2DTranspose, + Dropout, + ELU, + LeakyReLU, + Multiply, + ReLU, + Softmax) +from tensorflow.compat.v1 import logging +from tensorflow.compat.v1.keras.initializers import he_uniform +# pylint: enable=import-error + +from . import apply + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + + +def _get_conv_activation_layer(params): + """ + + :param params: + :returns: Required Activation function. + """ + conv_activation = params.get('conv_activation') + if conv_activation == 'ReLU': + return ReLU() + elif conv_activation == 'ELU': + return ELU() + return LeakyReLU(0.2) + + +def _get_deconv_activation_layer(params): + """ + + :param params: + :returns: Required Activation function. + """ + deconv_activation = params.get('deconv_activation') + if deconv_activation == 'LeakyReLU': + return LeakyReLU(0.2) + elif deconv_activation == 'ELU': + return ELU() + return ReLU() + + +def apply_unet( + input_tensor, + output_name='output', + params={}, + output_mask_logit=False): + """ Apply a convolutionnal U-net to model a single instrument (one U-net + is used for each instrument). + + :param input_tensor: + :param output_name: (Optional) , default to 'output' + :param params: (Optional) , default to empty dict. + :param output_mask_logit: (Optional) , default to False. + """ + logging.info(f'Apply unet for {output_name}') + conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512]) + conv_activation_layer = _get_conv_activation_layer(params) + deconv_activation_layer = _get_deconv_activation_layer(params) + kernel_initializer = he_uniform(seed=50) + conv2d_factory = partial( + Conv2D, + strides=(2, 2), + padding='same', + kernel_initializer=kernel_initializer) + # First layer. + conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor) + batch1 = BatchNormalization(axis=-1)(conv1) + rel1 = conv_activation_layer(batch1) + # Second layer. + conv2 = conv2d_factory(conv_n_filters[1], (5, 5))(rel1) + batch2 = BatchNormalization(axis=-1)(conv2) + rel2 = conv_activation_layer(batch2) + # Third layer. + conv3 = conv2d_factory(conv_n_filters[2], (5, 5))(rel2) + batch3 = BatchNormalization(axis=-1)(conv3) + rel3 = conv_activation_layer(batch3) + # Fourth layer. + conv4 = conv2d_factory(conv_n_filters[3], (5, 5))(rel3) + batch4 = BatchNormalization(axis=-1)(conv4) + rel4 = conv_activation_layer(batch4) + # Fifth layer. 
+ conv5 = conv2d_factory(conv_n_filters[4], (5, 5))(rel4) + batch5 = BatchNormalization(axis=-1)(conv5) + rel5 = conv_activation_layer(batch5) + # Sixth layer + conv6 = conv2d_factory(conv_n_filters[5], (5, 5))(rel5) + batch6 = BatchNormalization(axis=-1)(conv6) + _ = conv_activation_layer(batch6) + # + # + conv2d_transpose_factory = partial( + Conv2DTranspose, + strides=(2, 2), + padding='same', + kernel_initializer=kernel_initializer) + # + up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))((conv6)) + up1 = deconv_activation_layer(up1) + batch7 = BatchNormalization(axis=-1)(up1) + drop1 = Dropout(0.5)(batch7) + merge1 = Concatenate(axis=-1)([conv5, drop1]) + # + up2 = conv2d_transpose_factory(conv_n_filters[3], (5, 5))((merge1)) + up2 = deconv_activation_layer(up2) + batch8 = BatchNormalization(axis=-1)(up2) + drop2 = Dropout(0.5)(batch8) + merge2 = Concatenate(axis=-1)([conv4, drop2]) + # + up3 = conv2d_transpose_factory(conv_n_filters[2], (5, 5))((merge2)) + up3 = deconv_activation_layer(up3) + batch9 = BatchNormalization(axis=-1)(up3) + drop3 = Dropout(0.5)(batch9) + merge3 = Concatenate(axis=-1)([conv3, drop3]) + # + up4 = conv2d_transpose_factory(conv_n_filters[1], (5, 5))((merge3)) + up4 = deconv_activation_layer(up4) + batch10 = BatchNormalization(axis=-1)(up4) + merge4 = Concatenate(axis=-1)([conv2, batch10]) + # + up5 = conv2d_transpose_factory(conv_n_filters[0], (5, 5))((merge4)) + up5 = deconv_activation_layer(up5) + batch11 = BatchNormalization(axis=-1)(up5) + merge5 = Concatenate(axis=-1)([conv1, batch11]) + # + up6 = conv2d_transpose_factory(1, (5, 5), strides=(2, 2))((merge5)) + up6 = deconv_activation_layer(up6) + batch12 = BatchNormalization(axis=-1)(up6) + # Last layer to ensure initial shape reconstruction. + if not output_mask_logit: + up7 = Conv2D( + 2, + (4, 4), + dilation_rate=(2, 2), + activation='sigmoid', + padding='same', + kernel_initializer=kernel_initializer)((batch12)) + output = Multiply(name=output_name)([up7, input_tensor]) + return output + return Conv2D( + 2, + (4, 4), + dilation_rate=(2, 2), + padding='same', + kernel_initializer=kernel_initializer)((batch12)) + + +def unet(input_tensor, instruments, params={}): + """ Model function applier. """ + return apply(apply_unet, input_tensor, instruments, params) + + +def softmax_unet(input_tensor, instruments, params={}): + """ Apply softmax to multitrack unet in order to have mask suming to one. + + :param input_tensor: Tensor to apply blstm to. + :param instruments: Iterable that provides a collection of instruments. + :param params: (Optional) dict of BLSTM parameters. + :returns: Created output tensor dict. + """ + logit_mask_list = [] + for instrument in instruments: + out_name = f'{instrument}_spectrogram' + logit_mask_list.append( + apply_unet( + input_tensor, + output_name=out_name, + params=params, + output_mask_logit=True)) + masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4)) + output_dict = {} + for i, instrument in enumerate(instruments): + out_name = f'{instrument}_spectrogram' + output_dict[out_name] = Multiply(name=out_name)([ + masks[..., i], + input_tensor]) + return output_dict diff --git a/spleeter/model/provider/__init__.py b/spleeter/model/provider/__init__.py new file mode 100644 index 00000000..854b065c --- /dev/null +++ b/spleeter/model/provider/__init__.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# coding: utf8 + +""" + This package provides tools for downloading model from network + using remote storage abstraction. 
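
Stepping back to `softmax_unet` above: stacking the per-instrument logits and applying a Softmax over the instrument axis is what guarantees that the masks are positive and sum to one in every time-frequency bin, as a small numpy check illustrates:

```python
import numpy as np

logits = np.random.randn(512, 1024, 2, 4)  # (T, F, channels, instruments)
masks = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
assert np.allclose(masks.sum(axis=-1), 1.0)
```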
+
+    :Example:
+
+    >>> provider = MyProviderImplementation()
+    >>> provider.get('/path/to/local/storage', params)
+"""
+
+from abc import ABC, abstractmethod
+from os import environ, makedirs
+from os.path import exists, isabs, join, sep
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+
+class ModelProvider(ABC):
+    """
+    A ModelProvider manages model files on disk and downloads
+    them when they are not locally available.
+    """
+
+    DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models')
+    MODEL_PROBE_PATH = '.probe'
+
+    @abstractmethod
+    def download(self, name, path):
+        """ Download model denoted by the given name to disk.
+
+        :param name: Name of the model to download.
+        :param path: Path of the directory to save model into.
+        """
+        pass
+
+    def writeProbe(self, directory):
+        """ Write a model probe file into the given directory.
+
+        :param directory: Directory to write probe into.
+        """
+        with open(join(directory, self.MODEL_PROBE_PATH), 'w') as stream:
+            stream.write('OK')
+
+    def get(self, model_directory):
+        """ Ensures required model is available at given location.
+
+        :param model_directory: Expected model_directory to be available.
+        :raise IOError: If model can not be retrieved.
+        """
+        # Expand model directory if needed.
+        if not isabs(model_directory):
+            model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
+        # Download it if it does not exist yet.
+        model_probe = join(model_directory, self.MODEL_PROBE_PATH)
+        if not exists(model_probe):
+            if not exists(model_directory):
+                makedirs(model_directory)
+            self.download(
+                model_directory.split(sep)[-1],
+                model_directory)
+            self.writeProbe(model_directory)
+        return model_directory
+
+
+def get_default_model_provider():
+    """ Builds and returns a default model provider.
+
+    :returns: A default model provider instance to use.
+    """
+    from .github import GithubModelProvider
+    host = environ.get('GITHUB_HOST', 'https://github.com')
+    repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter')
+    release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE)
+    return GithubModelProvider(host, repository, release)
diff --git a/spleeter/model/provider/github.py b/spleeter/model/provider/github.py
new file mode 100644
index 00000000..cc7028c8
--- /dev/null
+++ b/spleeter/model/provider/github.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+# coding: utf8
+
+"""
+    A ModelProvider backed by the Github Release feature.
+
+    :Example:
+
+    >>> from spleeter.model.provider import github
+    >>> provider = github.GithubModelProvider(
+            'github.com',
+            'Deezer/spleeter',
+            'latest')
+    >>> provider.download('2stems', '/path/to/local/storage')
+"""
+
+import tarfile
+
+from os import environ
+from tempfile import TemporaryFile
+from shutil import copyfileobj
+
+import requests
+
+from . import ModelProvider
+from ...utils.logging import get_logger
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+
+class GithubModelProvider(ModelProvider):
+    """ A ModelProvider implementation backed by Github for remote storage. """
+
+    LATEST_RELEASE = 'v1.4.0'
+    RELEASE_PATH = 'releases/download'
+
+    def __init__(self, host, repository, release):
+        """ Default constructor.
+
+        :param host: Host of the Github instance to reach.
+        :param repository: Repository path within target Github.
+        :param release: Release name to get models from.
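
Because the default provider reads its coordinates from the environment, model downloads can be pointed at a fork or mirror without code changes; the hostname and repository below are invented:

```python
import os

# Invented mirror coordinates; the defaults are https://github.com,
# deezer/spleeter and release v1.4.0.
os.environ['GITHUB_HOST'] = 'https://github.example.com'
os.environ['GITHUB_REPOSITORY'] = 'myfork/spleeter'
os.environ['GITHUB_RELEASE'] = 'v1.4.0'
# get_default_model_provider() would now download archives from
# https://github.example.com/myfork/spleeter/releases/download/v1.4.0/<name>.tar.gz
```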
+ """ + self._host = host + self._repository = repository + self._release = release + + def download(self, name, path): + """ Download model denoted by the given name to disk. + + :param name: Name of the model to download. + :param path: Path of the directory to save model into. + """ + url = '{}/{}/{}/{}/{}.tar.gz'.format( + self._host, + self._repository, + self.RELEASE_PATH, + self._release, + name) + get_logger().info('Downloading model archive %s', url) + response = requests.get(url, stream=True) + if response.status_code != 200: + raise IOError(f'Resource {url} not found') + with TemporaryFile() as stream: + copyfileobj(response.raw, stream) + get_logger().debug('Extracting downloaded archive') + stream.seek(0) + tar = tarfile.open(fileobj=stream) + tar.extractall(path=path) + tar.close() + get_logger().debug('Model file extracted') diff --git a/spleeter/resources/2stems.json b/spleeter/resources/2stems.json new file mode 100644 index 00000000..5f0f5fe2 --- /dev/null +++ b/spleeter/resources/2stems.json @@ -0,0 +1,28 @@ +{ + "train_csv": "path/to/train.csv", + "validation_csv": "path/to/test.csv", + "model_dir": "2stems", + "mix_name": "mix", + "instrument_list": ["vocals", "accompaniment"], + "sample_rate":44100, + "frame_length":4096, + "frame_step":1024, + "T":512, + "F":1024, + "n_channels":2, + "separation_exponent":2, + "mask_extension":"zeros", + "learning_rate": 1e-4, + "batch_size":4, + "training_cache":"training_cache", + "validation_cache":"validation_cache", + "train_max_steps": 1000000, + "throttle_secs":300, + "random_seed":0, + "save_checkpoints_steps":150, + "save_summary_steps":5, + "model":{ + "type":"unet.unet", + "params":{} + } +} diff --git a/spleeter/resources/4stems.json b/spleeter/resources/4stems.json new file mode 100644 index 00000000..a6461546 --- /dev/null +++ b/spleeter/resources/4stems.json @@ -0,0 +1,31 @@ +{ + "train_csv": "path/to/train.csv", + "validation_csv": "path/to/val.csv", + "model_dir": "4stems", + "mix_name": "mix", + "instrument_list": ["vocals", "drums", "bass", "other"], + "sample_rate":44100, + "frame_length":4096, + "frame_step":1024, + "T":512, + "F":1024, + "n_channels":2, + "separation_exponent":2, + "mask_extension":"zeros", + "learning_rate": 1e-4, + "batch_size":4, + "training_cache":"training_cache", + "validation_cache":"validation_cache", + "train_max_steps": 1500000, + "throttle_secs":600, + "random_seed":3, + "save_checkpoints_steps":300, + "save_summary_steps":5, + "model":{ + "type":"unet.unet", + "params":{ + "conv_activation":"ELU", + "deconv_activation":"ELU" + } + } +} diff --git a/spleeter/resources/5stems.json b/spleeter/resources/5stems.json new file mode 100644 index 00000000..aad63121 --- /dev/null +++ b/spleeter/resources/5stems.json @@ -0,0 +1,31 @@ +{ + "train_csv": "path/to/train.csv", + "validation_csv": "path/to/test.csv", + "model_dir": "5stems", + "mix_name": "mix", + "instrument_list": ["vocals", "piano", "drums", "bass", "other"], + "sample_rate":44100, + "frame_length":4096, + "frame_step":1024, + "T":512, + "F":1024, + "n_channels":2, + "separation_exponent":2, + "mask_extension":"zeros", + "learning_rate": 1e-4, + "batch_size":4, + "training_cache":"training_cache", + "validation_cache":"validation_cache", + "train_max_steps": 2500000, + "throttle_secs":600, + "random_seed":8, + "save_checkpoints_steps":300, + "save_summary_steps":5, + "model":{ + "type":"unet.softmax_unet", + "params":{ + "conv_activation":"ELU", + "deconv_activation":"ELU" + } + } +} diff --git a/spleeter/resources/__init__.py 
b/spleeter/resources/__init__.py new file mode 100644 index 00000000..41d2a651 --- /dev/null +++ b/spleeter/resources/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# coding: utf8 + +""" Packages that provides static resources file for the library. """ + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' diff --git a/spleeter/resources/musdb.json b/spleeter/resources/musdb.json new file mode 100644 index 00000000..0d447006 --- /dev/null +++ b/spleeter/resources/musdb.json @@ -0,0 +1,32 @@ +{ + "train_csv": "configs/musdb_train.csv", + "validation_csv": "configs/musdb_validation.csv", + "model_dir": "musdb_model", + "mix_name": "mix", + "instrument_list": ["vocals", "drums", "bass", "other"], + "sample_rate":44100, + "frame_length":4096, + "frame_step":1024, + "T":512, + "F":1024, + "n_channels":2, + "n_chunks_per_song":1, + "separation_exponent":2, + "mask_extension":"zeros", + "learning_rate": 1e-4, + "batch_size":4, + "training_cache":"training_cache", + "validation_cache":"validation_cache", + "train_max_steps": 100000, + "throttle_secs":600, + "random_seed":3, + "save_checkpoints_steps":300, + "save_summary_steps":5, + "model":{ + "type":"unet.unet", + "params":{ + "conv_activation":"ELU", + "deconv_activation":"ELU" + } + } +} diff --git a/spleeter/separator.py b/spleeter/separator.py new file mode 100644 index 00000000..a238037a --- /dev/null +++ b/spleeter/separator.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# coding: utf8 + +""" + Module that provides a class wrapper for source separation. + + :Example: + + >>> from spleeter.separator import Separator + >>> separator = Separator('spleeter:2stems') + >>> separator.separate(waveform, lambda instrument, data: ...) + >>> separator.separate_to_file(...) +""" + +import os +import json + +from functools import partial +from multiprocessing import Pool +from pathlib import Path +from os.path import join + +from .model import model_fn +from .utils.audio.adapter import get_default_audio_adapter +from .utils.audio.convertor import to_stereo +from .utils.configuration import load_configuration +from .utils.estimator import create_estimator, to_predictor + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + + +class Separator(object): + """ A wrapper class for performing separation. """ + + def __init__(self, params_descriptor, MWF=False): + """ Default constructor. + + :param params_descriptor: Descriptor for TF params to be used. + :param MWF: (Optional) True if MWF should be used, False otherwise. + """ + self._params = load_configuration(params_descriptor) + self._sample_rate = self._params['sample_rate'] + self._MWF = MWF + self._predictor = None + self._pool = Pool() + self._tasks = [] + + def _get_predictor(self): + """ Lazy loading access method for internal predictor instance. + + :returns: Predictor to use for source separation. + """ + if self._predictor is None: + estimator = create_estimator(self._params, self._MWF) + self._predictor = to_predictor(estimator) + return self._predictor + + def join(self, timeout=20): + """ Wait for all pending tasks to be finished. + + :param timeout: (Optional) task waiting timeout. + """ + while len(self._tasks) > 0: + task = self._tasks.pop() + task.get() + task.wait(timeout=timeout) + + def separate(self, waveform): + """ Performs source separation over the given waveform. 
+
+        The separation is performed synchronously, but the result
+        processing is done asynchronously, allowing for instance
+        to export audio in parallel (through multiprocessing).
+
+        :param waveform: Waveform to apply separation on.
+        :returns: Separated waveforms as a dict (instrument name to waveform).
+        """
+        if waveform.shape[-1] != 2:
+            waveform = to_stereo(waveform)
+        predictor = self._get_predictor()
+        prediction = predictor({
+            'waveform': waveform,
+            'audio_id': ''})
+        prediction.pop('audio_id')
+        return prediction
+
+    def separate_to_file(
+            self, audio_descriptor, destination,
+            audio_adapter=get_default_audio_adapter(),
+            offset=0, duration=600., codec='wav', bitrate='128k',
+            synchronous=True):
+        """ Performs source separation and export result to file using
+        given audio adapter.
+
+        :param audio_descriptor: Descriptor of the song to separate: the
+                                 audio adapter uses it to retrieve and load
+                                 audio data. For a file based audio adapter,
+                                 it is a file path.
+        :param destination: Target directory to write output to.
+        :param audio_adapter: (Optional) Audio adapter to use for I/O.
+        :param offset: (Optional) Offset of loaded song.
+        :param duration: (Optional) Duration of loaded song.
+        :param codec: (Optional) Export codec.
+        :param bitrate: (Optional) Export bitrate.
+        :param synchronous: (Optional) True if operation should be synchronous.
+        """
+        waveform, _ = audio_adapter.load(
+            audio_descriptor,
+            offset=offset,
+            duration=duration,
+            sample_rate=self._sample_rate)
+        sources = self.separate(waveform)
+        for instrument, data in sources.items():
+            task = self._pool.apply_async(audio_adapter.save, (
+                join(destination, f'{instrument}.{codec}'),
+                data,
+                self._sample_rate,
+                codec,
+                bitrate))
+            self._tasks.append(task)
+        if synchronous:
+            self.join()
diff --git a/spleeter/utils/__init__.py b/spleeter/utils/__init__.py
new file mode 100644
index 00000000..a4ccb5bc
--- /dev/null
+++ b/spleeter/utils/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" This package provides utility functions and classes. """
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
diff --git a/spleeter/utils/audio/__init__.py b/spleeter/utils/audio/__init__.py
new file mode 100644
index 00000000..02f83c0f
--- /dev/null
+++ b/spleeter/utils/audio/__init__.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python
+# coding: utf8
+
+"""
+    `spleeter.utils.audio` package provides various
+    tools for manipulating audio content such as:
+
+    - Audio adapter class for abstract interaction with audio files.
+    - FFMPEG implementation for audio adapter.
+    - Waveform conversion and transformation functions.
+"""
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
diff --git a/spleeter/utils/audio/adapter.py b/spleeter/utils/audio/adapter.py
new file mode 100644
index 00000000..b2d7cb1f
--- /dev/null
+++ b/spleeter/utils/audio/adapter.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" AudioAdapter class definition.
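
Putting the `Separator` API above together, a hedged end-to-end sketch; it assumes the package and the pretrained 2stems model are available, and the input path is a placeholder:

```python
from spleeter.separator import Separator

separator = Separator('spleeter:2stems')
# Writes output/vocals.wav and output/accompaniment.wav; with the default
# synchronous=True the call blocks until both exports are finished.
separator.separate_to_file('audio_example.mp3', 'output', codec='wav')
```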
""" + +import subprocess + +from abc import ABC, abstractmethod +from importlib import import_module +from os.path import exists + +# pylint: disable=import-error +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.signal import stft, hann_window +# pylint: enable=import-error + +from ..logging import get_logger + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + + +class AudioAdapter(ABC): + """ An abstract class for manipulating audio signal. """ + + # Default audio adapter singleton instance. + DEFAULT = None + + @abstractmethod + def load( + self, audio_descriptor, offset, duration, + sample_rate, dtype=np.float32): + """ Loads the audio file denoted by the given audio descriptor + and returns it data as a waveform. Aims to be implemented + by client. + + :param audio_descriptor: Describe song to load, in case of file + based audio adapter, such descriptor would + be a file path. + :param offset: Start offset to load from in seconds. + :param duration: Duration to load in seconds. + :param sample_rate: Sample rate to load audio with. + :param dtype: Numpy data type to use, default to float32. + :returns: Loaded data as (wf, sample_rate) tuple. + """ + pass + + def load_tf_waveform( + self, audio_descriptor, + offset=0.0, duration=1800., sample_rate=44100, + dtype=b'float32', waveform_name='waveform'): + """ Load the audio and convert it to a tensorflow waveform. + + :param audio_descriptor: Describe song to load, in case of file + based audio adapter, such descriptor would + be a file path. + :param offset: Start offset to load from in seconds. + :param duration: Duration to load in seconds. + :param sample_rate: Sample rate to load audio with. + :param dtype: Numpy data type to use, default to float32. + :param waveform_name: (Optional) Name of the key in output dict. + :returns: TF output dict with waveform as + (T x chan numpy array) and a boolean that + tells whether there were an error while + trying to load the waveform. + """ + # Cast parameters to TF format. + offset = tf.cast(offset, tf.float64) + duration = tf.cast(duration, tf.float64) + + # Defined safe loading function. + def safe_load(path, offset, duration, sample_rate, dtype): + get_logger().info( + f'Loading audio {path} from {offset} to {offset + duration}') + try: + (data, _) = self.load( + path.numpy(), + offset.numpy(), + duration.numpy(), + sample_rate.numpy(), + dtype=dtype.numpy()) + return (data, False) + except Exception as e: + get_logger().warning(e) + return (np.float32(-1.0), True) + + # Execute function and format results. + results = tf.py_function( + safe_load, + [audio_descriptor, offset, duration, sample_rate, dtype], + (tf.float32, tf.bool)), + waveform, error = results[0] + return { + waveform_name: waveform, + f'{waveform_name}_error': error + } + + @abstractmethod + def save( + self, path, data, sample_rate, + codec=None, bitrate=None): + """ Save the given audio data to the file denoted by + the given path. + + :param path: Path of the audio file to save data in. + :param data: Waveform data to write. + :param sample_rate: Sample rate to write file in. + :param codec: (Optional) Writing codec to use. + :param bitrate: (Optional) Bitrate of the written audio file. + """ + pass + + +def get_default_audio_adapter(): + """ Builds and returns a default audio adapter instance. + + :returns: An audio adapter instance. 
+ """ + if AudioAdapter.DEFAULT is None: + from .ffmpeg import FFMPEGProcessAudioAdapter + AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter() + return AudioAdapter.DEFAULT + + +def get_audio_adapter(descriptor): + """ Load dynamically an AudioAdapter from given class descriptor. + + :param descriptor: Adapter class descriptor (module.Class) + :returns: Created adapter instance. + """ + if descriptor is None: + return get_default_audio_adapter() + module_path = descriptor.split('.') + adapter_class_name = module_path[-1] + module_path = '.'.join(module_path[:-1]) + adapter_module = import_module(module_path) + adapter_class = getattr(adapter_module, adapter_class_name) + if not isinstance(adapter_class, AudioAdapter): + raise ValueError( + f'{adapter_class_name} is not a valid AudioAdapter class') + return adapter_class() diff --git a/spleeter/utils/audio/convertor.py b/spleeter/utils/audio/convertor.py new file mode 100644 index 00000000..b6a79534 --- /dev/null +++ b/spleeter/utils/audio/convertor.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# coding: utf8 + +""" This module provides audio data convertion functions. """ + +# pylint: disable=import-error +import numpy as np +import tensorflow as tf +# pylint: enable=import-error + +from ..tensor import from_float32_to_uint8, from_uint8_to_float32 + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + + +def to_n_channels(waveform, n_channels): + """ Convert a waveform to n_channels by removing or + duplicating channels if needed (in tensorflow). + + :param waveform: Waveform to transform. + :param n_channels: Number of channel to reshape waveform in. + :returns: Reshaped waveform. + """ + return tf.cond( + tf.shape(waveform)[1] >= n_channels, + true_fn=lambda: waveform[:, :n_channels], + false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels] + ) + + +def to_stereo(waveform): + """ Convert a waveform to stereo by duplicating if mono, + or truncating if too many channels. + + :param waveform: a (N, d) numpy array. + :returns: A stereo waveform as a (N, 1) numpy array. + """ + if waveform.shape[1] == 1: + return np.repeat(waveform, 2, axis=-1) + if waveform.shape[1] > 2: + return waveform[:, :2] + return waveform + + +def gain_to_db(tensor, espilon=10e-10): + """ Convert from gain to decibel in tensorflow. + + :param tensor: Tensor to convert. + :param epsilon: Operation constant. + :returns: Converted tensor. + """ + return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon)) + + +def db_to_gain(tensor): + """ Convert from decibel to gain in tensorflow. + + :param tensor_db: Tensor to convert. + :returns: Converted tensor. + """ + return tf.pow(10., (tensor / 20.)) + + +def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs): + """ Encodes given spectrogram into uint8 using decibel scale. + + :param spectrogram: Spectrogram to be encoded as TF float tensor. + :param db_range: Range in decibel for encoding. + :returns: Encoded decibel spectrogram as uint8 tensor. + """ + db_spectrogram = gain_to_db(spectrogram) + max_db_spectrogram = tf.reduce_max(db_spectrogram) + db_spectrogram = tf.maximum(db_spectrogram, max_db_spectrogram - db_range) + return from_float32_to_uint8(db_spectrogram, **kwargs) + + +def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db): + """ Decode spectrogram from uint8 decibel scale. + + :param db_uint_spectrogram: Decibel pectrogram to decode. + :param min_db: Lower bound limit for decoding. 
diff --git a/spleeter/utils/audio/ffmpeg.py b/spleeter/utils/audio/ffmpeg.py
new file mode 100644
index 00000000..ad24e331
--- /dev/null
+++ b/spleeter/utils/audio/ffmpeg.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python
+# coding: utf8
+
+"""
+    This module provides an AudioAdapter implementation based on an
+    FFMPEG process. The implementation is POSIX-ish and depends on
+    nothing except standard Python libraries, which is why it is the
+    default adapter used within this library.
+"""
+
+import os
+import os.path
+import platform
+import re
+import subprocess
+
+import numpy as np  # pylint: disable=import-error
+
+from .adapter import AudioAdapter
+from ..logging import get_logger
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+# Default FFMPEG binary name.
+_UNIX_BINARY = 'ffmpeg'
+_WINDOWS_BINARY = 'ffmpeg.exe'
+
+
+def _which(program):
+    """ A pure python implementation of the `which` command
+    for retrieving an absolute path from a command name or path.
+
+    @see https://stackoverflow.com/a/377028/1211342
+
+    :param program: Program name or path to expand.
+    :returns: Absolute path of program if any, None otherwise.
+    """
+    def is_exe(fpath):
+        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
+
+    fpath, _ = os.path.split(program)
+    if fpath:
+        if is_exe(program):
+            return program
+    else:
+        for path in os.environ['PATH'].split(os.pathsep):
+            exe_file = os.path.join(path, program)
+            if is_exe(exe_file):
+                return exe_file
+    return None
+
+
+def _get_ffmpeg_path():
+    """ Retrieves FFMPEG binary path using the FFMPEG_PATH environment
+    variable if defined, or the default binary name (Windows or UNIX
+    style) otherwise.
+
+    :returns: Absolute path of FFMPEG binary.
+    :raise IOError: If FFMPEG binary cannot be found.
+    """
+    ffmpeg_path = os.environ.get('FFMPEG_PATH', None)
+    if ffmpeg_path is None:
+        # Note: infer standard binary name depending on the platform.
+        if platform.system() == 'Windows':
+            ffmpeg_path = _WINDOWS_BINARY
+        else:
+            ffmpeg_path = _UNIX_BINARY
+    expanded = _which(ffmpeg_path)
+    if expanded is None:
+        raise IOError(f'FFMPEG binary ({ffmpeg_path}) not found')
+    return expanded
+
+
+def _to_ffmpeg_time(n):
+    """ Format a number of seconds to the time format expected by FFMPEG.
+
+    :param n: Time in seconds to format.
+    :returns: Formatted time in FFMPEG format.
+    """
+    m, s = divmod(n, 60)
+    h, m = divmod(m, 60)
+    return '%d:%02d:%09.6f' % (h, m, s)
+
+
+def _parse_ffmpeg_results(stderr):
+    """ Extract the number of channels and the sample rate from
+    the given FFMPEG STDERR output line.
+
+    :param stderr: STDERR output line to parse.
+    :returns: Parsed n_channels and sample_rate values.
+    """
+    # Setup default values.
+    n_channels = 0
+    sample_rate = 0
+    # Find sample rate.
+    match = re.search(r'(\d+) hz', stderr)
+    if match:
+        sample_rate = int(match.group(1))
+    # Channel count.
+    match = re.search(r'hz, ([^,]+),', stderr)
+    if match:
+        mode = match.group(1)
+        if mode == 'stereo':
+            n_channels = 2
+        else:
+            match = re.match(r'(\d+) ', mode)
+            n_channels = match and int(match.group(1)) or 1
+    return n_channels, sample_rate
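
A quick illustration of the two pure helpers above, with arbitrary values; note that `_parse_ffmpeg_results` expects an already lowercased stderr line, which is how the load method further below feeds it:

```python
# Illustrative only; run inside the module's namespace.
print(_to_ffmpeg_time(3661.5))
# -> 1:01:01.500000
print(_parse_ffmpeg_results('stream #0:0: audio: mp3, 44100 hz, stereo, fltp'))
# -> (2, 44100)
```
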
""" + self._command.append(flag) + return self + + def opt(self, short, value, formatter=str): + """ Add option if value not None. """ + if value is not None: + self._command.append(short) + self._command.append(formatter(value)) + return self + + def command(self): + """ Build string command. """ + return self._command + + +class FFMPEGProcessAudioAdapter(AudioAdapter): + """ An AudioAdapter implementation that use FFMPEG binary through + subprocess in order to perform I/O operation for audio processing. + + When created, FFMPEG binary path will be checked and expended, + raising exception if not found. Such path could be infered using + FFMPEG_PATH environment variable. + """ + + def __init__(self): + """ Default constructor. """ + self._ffmpeg_path = _get_ffmpeg_path() + + def _get_command_builder(self): + """ Creates and returns a command builder using FFMPEG path. + + :returns: Built command builder. + """ + return _CommandBuilder(self._ffmpeg_path) + + def load( + self, path, offset=None, duration=None, + sample_rate=None, dtype=np.float32): + """ Loads the audio file denoted by the given path + and returns it data as a waveform. + + :param path: Path of the audio file to load data from. + :param offset: (Optional) Start offset to load from in seconds. + :param duration: (Optional) Duration to load in seconds. + :param sample_rate: (Optional) Sample rate to load audio with. + :param dtype: (Optional) Numpy data type to use, default to float32. + :returns: Loaded data a (waveform, sample_rate) tuple. + """ + if not isinstance(path, str): + path = path.decode() + command = ( + self._get_command_builder() + .opt('-ss', offset, formatter=_to_ffmpeg_time) + .opt('-t', duration, formatter=_to_ffmpeg_time) + .opt('-i', path) + .opt('-ar', sample_rate) + .opt('-f', 'f32le') + .flag('-') + .command()) + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + buffer = process.stdout.read(-1) + # Read STDERR until end of the process detected. + while True: + status = process.stderr.readline() + if not status: + raise OSError('Stream info not found') + if isinstance(status, bytes): # Note: Python 3 compatibility. + status = status.decode('utf8', 'ignore') + status = status.strip().lower() + if 'no such file' in status: + raise IOError(f'File {path} not found') + elif 'invalid data found' in status: + raise IOError(f'FFMPEG error : {status}') + elif 'audio:' in status: + n_channels, ffmpeg_sample_rate = _parse_ffmpg_results(status) + if sample_rate is None: + sample_rate = ffmpeg_sample_rate + break + # Load waveform and clean process. + waveform = np.frombuffer(buffer, dtype='0, default to 1. + :param mehtod: (Optional) Interpolation method, default to BILINEAR. + :returns: Time stretched spectrogram as tensor with same shape. + """ + T = tf.shape(spectrogram)[0] + T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0] + F = tf.shape(spectrogram)[1] + ts_spec = tf.image.resize_images( + spectrogram, + [T_ts, F], + method=method, + align_corners=True) + return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F) + + +def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs): + """ Time stretch a spectrogram preserving shape with random ratio in + tensorflow. Applies time_stretch to spectrogram with a random ratio drawn + uniformly in [factor_min, factor_max]. + + :param spectrogram: Input spectrogram to be time stretched as tensor. + :param factor_min: (Optional) Min time stretch factor, default to 0.9. 
+
+
+def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs):
+    """ Time stretch a spectrogram preserving shape with a random ratio in
+    tensorflow. Applies time_stretch to the spectrogram with a random ratio
+    drawn uniformly in [factor_min, factor_max].
+
+    :param spectrogram: Input spectrogram to be time stretched as tensor.
+    :param factor_min: (Optional) Min time stretch factor, default to 0.9.
+    :param factor_max: (Optional) Max time stretch factor, default to 1.1.
+    :returns: Randomly time stretched spectrogram as tensor with same shape.
+    """
+    factor = tf.random_uniform(
+        shape=(1,),
+        seed=0) * (factor_max - factor_min) + factor_min
+    return time_stretch(spectrogram, factor=factor, **kwargs)
+
+
+def pitch_shift(
+        spectrogram,
+        semitone_shift=0.0,
+        method=tf.image.ResizeMethod.BILINEAR):
+    """ Pitch shift a spectrogram preserving shape in tensorflow. Note that
+    this is an approximation in the frequency domain.
+
+    :param spectrogram: Input spectrogram to be pitch shifted as tensor.
+    :param semitone_shift: (Optional) Pitch shift in semitones, default to 0.0.
+    :param method: (Optional) Interpolation method, default to BILINEAR.
+    :returns: Pitch shifted spectrogram (same shape as spectrogram).
+    """
+    factor = 2 ** (semitone_shift / 12.)
+    T = tf.shape(spectrogram)[0]
+    F = tf.shape(spectrogram)[1]
+    F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
+    ps_spec = tf.image.resize_images(
+        spectrogram,
+        [T, F_ps],
+        method=method,
+        align_corners=True)
+    paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
+    return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
+
+
+def random_pitch_shift(spectrogram, shift_min=-1., shift_max=1., **kwargs):
+    """ Pitch shift a spectrogram preserving shape with a random ratio in
+    tensorflow. Applies pitch_shift to the spectrogram with a random shift
+    amount (expressed in semitones) drawn uniformly in [shift_min, shift_max].
+
+    :param spectrogram: Input spectrogram to be pitch shifted as tensor.
+    :param shift_min: (Optional) Min pitch shift in semitones, default to -1.
+    :param shift_max: (Optional) Max pitch shift in semitones, default to 1.
+    :returns: Randomly pitch shifted spectrogram (same shape as spectrogram).
+    """
+    semitone_shift = tf.random_uniform(
+        shape=(1,),
+        seed=0) * (shift_max - shift_min) + shift_min
+    return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)
diff --git a/spleeter/utils/configuration.py b/spleeter/utils/configuration.py
new file mode 100644
index 00000000..03db2009
--- /dev/null
+++ b/spleeter/utils/configuration.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" Module that provides a configuration loading function. """
+
+import json
+
+try:
+    import importlib.resources as loader
+except ImportError:
+    # Fall back to the `importlib_resources` backport on Python < 3.7.
+    import importlib_resources as loader
+
+from os.path import exists
+
+from .. import resources
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
+
+
+def load_configuration(descriptor):
+    """ Load configuration from the given descriptor, which is either
+    a `spleeter:` prefixed embedded configuration name or a file system
+    path to read configuration from.
+
+    :param descriptor: Configuration descriptor to use for lookup.
+    :returns: Loaded configuration as dict.
+    :raise ValueError: If the required embedded configuration does not exist.
+    :raise IOError: If the required configuration file does not exist.
+    """
+    # Embedded configuration reading.
+    if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
+        name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX):]
+        if not loader.is_resource(resources, f'{name}.json'):
+            raise ValueError(f'No embedded configuration {name} found')
+        with loader.open_text(resources, f'{name}.json') as stream:
+            return json.load(stream)
+    # Standard file reading.
+    if not exists(descriptor):
+        raise IOError(f'Configuration file {descriptor} not found')
+    with open(descriptor, 'r') as stream:
+        return json.load(stream)
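
Usage sketch for the loader: embedded configurations use the `spleeter:` prefix, anything else is treated as a JSON file path. The custom path below is hypothetical:

```python
from spleeter.utils.configuration import load_configuration

params = load_configuration('spleeter:2stems')        # embedded resource
custom = load_configuration('configs/my_model.json')  # hypothetical file path
```
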
diff --git a/spleeter/utils/estimator.py b/spleeter/utils/estimator.py
new file mode 100644
index 00000000..a908886f
--- /dev/null
+++ b/spleeter/utils/estimator.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" Utility functions for creating estimator. """
+
+from pathlib import Path
+
+# pylint: disable=import-error
+import tensorflow as tf
+
+from tensorflow.contrib import predictor
+# pylint: enable=import-error
+
+from ..model import model_fn
+from ..model.provider import get_default_model_provider
+
+# Default exporting directory for predictor.
+DEFAULT_EXPORT_DIRECTORY = '/tmp/serving'
+
+
+def create_estimator(params, MWF):
+    """
+    Initialize a tensorflow estimator that will perform separation.
+
+    Params:
+    - params: a dictionary of parameters for building the model
+    - MWF: a boolean that enables multichannel Wiener filtering
+
+    Returns:
+    a tensorflow estimator
+    """
+    # Load model.
+    model_directory = params['model_dir']
+    model_provider = get_default_model_provider()
+    params['model_dir'] = model_provider.get(model_directory)
+    params['MWF'] = MWF
+    # Setup config.
+    session_config = tf.compat.v1.ConfigProto()
+    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
+    config = tf.estimator.RunConfig(session_config=session_config)
+    # Setup estimator.
+    estimator = tf.estimator.Estimator(
+        model_fn=model_fn,
+        model_dir=params['model_dir'],
+        params=params,
+        config=config
+    )
+    return estimator
+
+
+def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
+    """ Exports the given estimator as predictor into the given directory
+    and returns the associated tf.predictor instance.
+
+    :param estimator: Estimator to export.
+    :param directory: (Optional) path to write exported model into.
+    :returns: Predictor instance loaded from the latest exported model.
+    """
+    def receiver():
+        shape = (None, estimator.params['n_channels'])
+        features = {
+            'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
+            'audio_id': tf.compat.v1.placeholder(tf.string)}
+        return tf.estimator.export.ServingInputReceiver(features, features)
+
+    estimator.export_saved_model(directory, receiver)
+    versions = [
+        model for model in Path(directory).iterdir()
+        if model.is_dir() and 'temp' not in str(model)]
+    latest = str(sorted(versions)[-1])
+    return predictor.from_saved_model(latest)
diff --git a/spleeter/utils/logging.py b/spleeter/utils/logging.py
new file mode 100644
index 00000000..031e0c34
--- /dev/null
+++ b/spleeter/utils/logging.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" Centralized logging facilities for Spleeter. """
+
+from os import environ
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+
+class _LoggerHolder(object):
+    """ Logger singleton instance holder. """
+
+    INSTANCE = None
+
+
+def get_logger():
+    """ Returns the library scoped logger.
+
+    :returns: Library logger.
+    """
+    if _LoggerHolder.INSTANCE is None:
+        # pylint: disable=import-error
+        from tensorflow.compat.v1 import logging
+        # pylint: enable=import-error
+        _LoggerHolder.INSTANCE = logging
+        _LoggerHolder.INSTANCE.set_verbosity(_LoggerHolder.INSTANCE.ERROR)
+        environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+    return _LoggerHolder.INSTANCE
+
+
+def enable_logging():
+    """ Enable INFO level logging. """
+    environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
+    logger = get_logger()
+    logger.set_verbosity(logger.INFO)
+
+
+def enable_verbose_logging():
+    """ Enable DEBUG level logging. """
+    environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
+    logger = get_logger()
+    logger.set_verbosity(logger.DEBUG)
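
How the two estimator helpers are meant to compose, as a sketch; it assumes the loaded configuration carries the keys read above (`model_dir` for create_estimator, `n_channels` for the serving receiver), which this diff does not show:

```python
# Sketch of the estimator-to-predictor flow; MWF toggles multichannel
# Wiener filtering (an interpretation of the parameter name).
from spleeter.utils.configuration import load_configuration
from spleeter.utils.estimator import create_estimator, to_predictor

params = load_configuration('spleeter:2stems')
estimator = create_estimator(params, MWF=False)
predictor = to_predictor(estimator)
```
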
""" + environ['TF_CPP_MIN_LOG_LEVEL'] = '0' + logger = get_logger() + logger.set_verbosity(logger.DEBUG) diff --git a/spleeter/utils/tensor.py b/spleeter/utils/tensor.py new file mode 100644 index 00000000..402548c2 --- /dev/null +++ b/spleeter/utils/tensor.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +# coding: utf8 + +""" Utility function for tensorflow. """ + +# pylint: disable=import-error +import tensorflow as tf +import pandas as pd +# pylint: enable=import-error + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + + +def sync_apply(tensor_dict, func, concat_axis=1): + """ Return a function that applies synchronously the provided func on the + provided dictionnary of tensor. This means that func is applied to the + concatenation of the tensors in tensor_dict. This is useful for performing + random operation that needs the same drawn value on multiple tensor, such + as a random time-crop on both input data and label (the same crop should be + applied to both input data and label, so random crop cannot be applied + separately on each of them). + + IMPORTANT NOTE: all tensor are assumed to be the same shape. + + Params: + - tensor_dict: dictionary (key: strings, values: tf.tensor) + a dictionary of tensor. + - func: function + function to be applied to the concatenation of the tensors in + tensor_dict + - concat_axis: int + The axis on which to perform the concatenation. + + Returns: + processed tensors dictionary with the same name (keys) as input + tensor_dict. + """ + if concat_axis not in {0, 1}: + raise NotImplementedError( + 'Function only implemented for concat_axis equal to 0 or 1') + tensor_list = list(tensor_dict.values()) + concat_tensor = tf.concat(tensor_list, concat_axis) + processed_concat_tensor = func(concat_tensor) + tensor_shape = tf.shape(list(tensor_dict.values())[0]) + D = tensor_shape[concat_axis] + if concat_axis == 0: + return { + name: processed_concat_tensor[index * D:(index + 1) * D, :, :] + for index, name in enumerate(tensor_dict) + } + return { + name: processed_concat_tensor[:, index * D:(index + 1) * D, :] + for index, name in enumerate(tensor_dict) + } + + +def from_float32_to_uint8( + tensor, + tensor_key='tensor', + min_key='min', + max_key='max'): + """ + + :param tensor: + :param tensor_key: + :param min_key: + :param max_key: + :returns: + """ + tensor_min = tf.reduce_min(tensor) + tensor_max = tf.reduce_max(tensor) + return { + tensor_key: tf.cast( + (tensor - tensor_min) / (tensor_max - tensor_min + 1e-16) + * 255.9999, dtype=tf.uint8), + min_key: tensor_min, + max_key: tensor_max + } + + +def from_uint8_to_float32(tensor, tensor_min, tensor_max): + """ + + :param tensor: + :param tensor_min: + :param tensor_max: + :returns: + """ + return ( + tf.cast(tensor, tf.float32) + * (tensor_max - tensor_min) + / 255.9999 + tensor_min) + + +def pad_and_partition(tensor, segment_len): + """ Pad and partition a tensor into segment of len segment_len + along the first dimension. The tensor is padded with 0 in order + to ensure that the first dimension is a multiple of segment_len. 
+
+
+def pad_and_partition(tensor, segment_len):
+    """ Pad and partition a tensor into segments of length segment_len
+    along the first dimension. The tensor is padded with 0 in order
+    to ensure that the first dimension is a multiple of segment_len.
+
+    Tensor must be of known fixed rank.
+
+    :Example:
+
+    >>> tensor = [[1, 2], [3, 4], [5, 6]]
+    >>> segment_len = 2
+    >>> pad_and_partition(tensor, segment_len)
+    [[[1, 2], [3, 4]], [[5, 6], [0, 0]]]
+
+    :param tensor: Tensor to pad and partition.
+    :param segment_len: Segment length along the first dimension.
+    :returns: Padded and partitioned tensor.
+    """
+    tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
+    pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
+    padded = tf.pad(
+        tensor,
+        [[0, pad_size]] + [[0, 0]] * (len(tensor.shape) - 1))
+    split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
+    return tf.reshape(
+        padded,
+        tf.concat(
+            [[split, segment_len], tf.shape(padded)[1:]],
+            axis=0))
+
+
+def pad_and_reshape(instr_spec, frame_length, F):
+    """ Zero-pads the frequency dimension of the given instrument
+    spectrogram up to frame_length // 2 + 1 bins and merges its two
+    leading dimensions.
+
+    :param instr_spec: Instrument spectrogram to pad and reshape.
+    :param frame_length: STFT frame length, used to derive the target
+                         number of frequency bins.
+    :param F: Current number of frequency bins.
+    :returns: Padded and reshaped spectrogram.
+    """
+    spec_shape = tf.shape(instr_spec)
+    extension_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))
+    n_extra_row = frame_length // 2 + 1 - F
+    extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
+    extended_spec = tf.concat([instr_spec, extension], axis=2)
+    old_shape = tf.shape(extended_spec)
+    new_shape = tf.concat([
+        [old_shape[0] * old_shape[1]],
+        old_shape[2:]],
+        axis=0)
+    processed_instr_spec = tf.reshape(extended_spec, new_shape)
+    return processed_instr_spec
+
+
+def dataset_from_csv(csv_path, **kwargs):
+    """ Load dataset from a CSV file using Pandas. kwargs, if any, are
+    forwarded to the `pandas.read_csv` function.
+
+    :param csv_path: Path of the CSV file to load dataset from.
+    :returns: Loaded dataset.
+    """
+    df = pd.read_csv(csv_path, **kwargs)
+    dataset = (
+        tf.data.Dataset.from_tensor_slices(
+            {key: df[key].values for key in df})
+    )
+    return dataset
+
+
+def check_tensor_shape(tensor_tf, target_shape):
+    """ Returns a Tensorflow boolean graph that indicates whether
+    the given tensor has the specified target shape. Only checks
+    the not-None entries of target_shape.
+
+    :param tensor_tf: Tensor to check shape for.
+    :param target_shape: Target shape to compare tensor to.
+    :returns: True if shape is valid, False otherwise (as TF boolean).
+    """
+    result = tf.constant(True)
+    for i, target_length in enumerate(target_shape):
+        if target_length:
+            result = tf.logical_and(
+                result,
+                tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i]))
+    return result
+
+
+def set_tensor_shape(tensor, tensor_shape):
+    """ Set static shape information for a tensor and return it.
+
+    :param tensor: Tensor to set the shape of.
+    :param tensor_shape: Shape to apply to the tensor.
+    :returns: The same tensor, with its static shape set.
+    """
+    # Note: tf.Tensor.set_shape mutates the tensor's static shape in
+    # place; the tensor is returned to allow a functional call style.
+    tensor.set_shape(tensor_shape)
+    return tensor
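
Finally, a small TF1 session check of pad_and_partition matching the docstring example above:

```python
import tensorflow as tf

from spleeter.utils.tensor import pad_and_partition

tensor = tf.constant([[1, 2], [3, 4], [5, 6]])
segments = pad_and_partition(tensor, 2)
with tf.Session() as session:
    print(session.run(segments))
# -> [[[1 2]
#      [3 4]]
#     [[5 6]
#      [0 0]]]
```
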