From 68e05afff486f85218cf4192803fbd287e25b05f Mon Sep 17 00:00:00 2001 From: jorshi Date: Tue, 2 Jul 2024 11:25:02 +0100 Subject: [PATCH] initial commit -- moving dev repo to release --- .github/workflows/test.yaml | 33 ++ .gitignore | 180 ++++++++ .pre-commit-config.yaml | 23 + LICENSE | 201 ++++++++ README.md | 81 ++++ cfg/data/carson_pop_damp.yaml | 6 + cfg/data/carson_pop_no_damp.yaml | 6 + cfg/data/carson_supra_damp.yaml | 6 + cfg/data/carson_supra_no_damp.yaml | 6 + cfg/data/snare_onset.yaml | 10 + cfg/features/onset_feature_full.yaml | 33 ++ cfg/features/onset_features.yaml | 14 + cfg/models/linear_mapper.yaml | 8 + cfg/models/mlp_mapper.yaml | 11 + cfg/models/mlp_mapper_lrg.yaml | 11 + cfg/onset_mapping_808.yaml | 32 ++ cfg/onset_mapping_808_linear.yaml | 32 ++ cfg/onset_mapping_808_lrg.yaml | 32 ++ cfg/presets/808_noisy_snare.json | 22 + cfg/presets/808_open_snare.json | 22 + cfg/presets/808_snare_1.json | 22 + cfg/presets/808_snare_2.json | 22 + cfg/presets/808_snare_3.json | 22 + cfg/presets/808_tom_1.json | 22 + cfg/synths/snare_808.yaml | 6 + pyproject.toml | 46 ++ scripts/direct_optimize.sh | 5 + scripts/results.py | 196 ++++++++ scripts/train_linear.sh | 14 + scripts/train_mlp.sh | 14 + scripts/train_mlp_lrg.sh | 14 + setup.cfg | 8 + test/__init__.py | 0 test/conftest.py | 3 + test/test_data.py | 62 +++ test/test_feature.py | 84 ++++ test/test_model.py | 30 ++ test/test_np_core.py | 51 +++ test/test_synth.py | 45 ++ test/test_tasks.py | 80 ++++ timbreremap/__init__.py | 0 timbreremap/callback.py | 162 +++++++ timbreremap/cli.py | 208 +++++++++ timbreremap/data.py | 313 +++++++++++++ timbreremap/export.py | 28 ++ timbreremap/feature.py | 519 +++++++++++++++++++++ timbreremap/loss.py | 27 ++ timbreremap/model.py | 102 +++++ timbreremap/np/__init__.py | 6 + timbreremap/np/core.py | 174 +++++++ timbreremap/np/features.py | 45 ++ timbreremap/optuna.py | 134 ++++++ timbreremap/synth.py | 654 +++++++++++++++++++++++++++ timbreremap/tasks.py | 178 ++++++++ timbreremap/utils/model.py | 51 +++ 55 files changed, 4116 insertions(+) create mode 100644 .github/workflows/test.yaml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 LICENSE create mode 100644 README.md create mode 100644 cfg/data/carson_pop_damp.yaml create mode 100644 cfg/data/carson_pop_no_damp.yaml create mode 100644 cfg/data/carson_supra_damp.yaml create mode 100644 cfg/data/carson_supra_no_damp.yaml create mode 100644 cfg/data/snare_onset.yaml create mode 100644 cfg/features/onset_feature_full.yaml create mode 100644 cfg/features/onset_features.yaml create mode 100644 cfg/models/linear_mapper.yaml create mode 100644 cfg/models/mlp_mapper.yaml create mode 100644 cfg/models/mlp_mapper_lrg.yaml create mode 100644 cfg/onset_mapping_808.yaml create mode 100644 cfg/onset_mapping_808_linear.yaml create mode 100644 cfg/onset_mapping_808_lrg.yaml create mode 100644 cfg/presets/808_noisy_snare.json create mode 100644 cfg/presets/808_open_snare.json create mode 100644 cfg/presets/808_snare_1.json create mode 100644 cfg/presets/808_snare_2.json create mode 100644 cfg/presets/808_snare_3.json create mode 100644 cfg/presets/808_tom_1.json create mode 100644 cfg/synths/snare_808.yaml create mode 100644 pyproject.toml create mode 100755 scripts/direct_optimize.sh create mode 100644 scripts/results.py create mode 100755 scripts/train_linear.sh create mode 100755 scripts/train_mlp.sh create mode 100755 scripts/train_mlp_lrg.sh create mode 100644 setup.cfg create mode 100644 test/__init__.py create 
mode 100644 test/conftest.py
create mode 100644 test/test_data.py
create mode 100644 test/test_feature.py
create mode 100644 test/test_model.py
create mode 100644 test/test_np_core.py
create mode 100644 test/test_synth.py
create mode 100644 test/test_tasks.py
create mode 100644 timbreremap/__init__.py
create mode 100644 timbreremap/callback.py
create mode 100644 timbreremap/cli.py
create mode 100644 timbreremap/data.py
create mode 100644 timbreremap/export.py
create mode 100644 timbreremap/feature.py
create mode 100644 timbreremap/loss.py
create mode 100644 timbreremap/model.py
create mode 100644 timbreremap/np/__init__.py
create mode 100644 timbreremap/np/core.py
create mode 100644 timbreremap/np/features.py
create mode 100644 timbreremap/optuna.py
create mode 100644 timbreremap/synth.py
create mode 100644 timbreremap/tasks.py
create mode 100644 timbreremap/utils/model.py

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
new file mode 100644
index 0000000..f6ff9ab
--- /dev/null
+++ b/.github/workflows/test.yaml
@@ -0,0 +1,33 @@
+name: Test
+
+on:
+  push:
+    branches: ["main"]
+  pull_request:
+    branches: ["main"]
+
+permissions:
+  contents: read
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Install Linux dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y ffmpeg
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+          cache: "pip" # cache pip dependencies
+          cache-dependency-path: "**/pyproject.toml" # cache dependencies based on pyproject.toml
+      - name: Install Python dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install ".[dev]"
+      - name: Test with pytest
+        run: |
+          pytest
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..78adeb5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,180 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +audio/ +.DS_Store +*.code-workspace + +# Trained model files +*.pt + +# Logs +lightning_logs/ + +# Results (for now) +results/ + +# Archives +archives/ + +# Numerical experiments +experiment/ +table.tex diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c462ff4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/kynan/nbstripout + rev: 0.6.0 + hooks: + - id: nbstripout + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + - repo: https://github.com/pycqa/flake8 + rev: "6.0.0" # pick a git hash / tag to point to + hooks: + - id: flake8 + - repo: https://github.com/asottile/reorder_python_imports + rev: "v3.9.0" + hooks: + - id: reorder-python-imports diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..b065f84 --- /dev/null +++ b/README.md @@ -0,0 +1,81 @@ +
+ +# Real-time Timbre Remapping with Differentiable DSP + + +[![Demo](https://img.shields.io/badge/Web-Audio_Examples-blue)](https://jordieshier.com/projects/nime2024/) +[![Paper](https://img.shields.io/badge/PDF-Paper-green)](#) + +[Jordie Shier](https://jordieshier.com), [Charalampos Saitis](http://eecs.qmul.ac.uk/people/profiles/saitischaralampos.html), Andrew Robertson, and [Andrew McPherson](https://www.imperial.ac.uk/people/andrew.mcpherson) + +
+
+This repository contains the training code for our NIME 2024 paper *Real-time Timbre Remapping with Differentiable DSP*.
+This research applies differentiable digital signal processing (DDSP) to the timbral control of an audio
+synthesizer: timbral changes, represented as audio features extracted from an input audio signal (e.g., drums),
+are mapped to synthesizer parameter modulations. To enable real-time control, we introduce neural networks that
+learn to map short windows of audio features, extracted at each detected onset, to synthesizer parameters.
+This allows a synthesizer to be driven in real time from an audio input.
+
+
+## Install
+Clone the repo and then install the `timbreremap` package. Requires Python 3.9 or greater.
+
+```bash
+pip install --upgrade pip
+pip install -e .
+```
+
+## Example Training and Inference
+
+Download the snare drum performance audio (a sketch of a training command using this audio is given at the end of this README):
+
+```bash
+mkdir audio
+cd audio
+wget https://pub-814e66019388451395cf43c0b6f10300.r2.dev/carson.zip
+unzip carson.zip
+cd ..
+```
+
+## Numerical Experiments
+
+Instructions to reproduce the numerical results from the NIME 2024 paper.
+
+### Dataset
+
+Download the [Snare Drum Data Set (SDSS)](https://aes2.org/publications/elibrary-page/?id=20912).
+We used a subset of this dataset, which can be downloaded as follows:
+
+```bash
+mkdir audio
+cd audio
+wget https://pub-814e66019388451395cf43c0b6f10300.r2.dev/sdss_filtered.zip
+unzip sdss_filtered.zip
+```
+
+### Training
+
+The following scripts run a series of trainings iterating over the snare drum
+dataset and five different synthesizer presets. In total, 240 models are trained for
+each mapping algorithm.
+Each model takes around two minutes to train on a GPU, so training all models takes roughly 24 hours.
+
+```bash
+./scripts/train_linear.sh && ./scripts/train_mlp.sh && ./scripts/train_mlp_lrg.sh
+```
+
+To run the baseline, which uses no neural network and instead estimates the synthesis parameters directly using gradient descent:
+
+```bash
+./scripts/direct_optimize.sh
+```
+
+### Results
+To compile the results from the numerical experiments into a summary table:
+
+```bash
+python scripts/results.py experiment
+```
+
+This will output a file named `table.tex` with a table similar to the one presented in the paper.
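+
+### Example training run (sketch)
+
+The `timbreremap` console script points at `timbreremap.cli:main`. Assuming it exposes the same
+LightningCLI-style interface as the `timbreremap-train-sdss` command used in the `scripts/` directory
+(this invocation is an illustrative sketch, not a verified command), a single training run on the
+downloaded snare performances might look like:
+
+```bash
+timbreremap fit -c cfg/onset_mapping_808.yaml --data cfg/data/carson_pop_damp.yaml --model.preset cfg/presets/808_snare_1.json --trainer.default_root_dir lightning_logs
+```
+
+On completion, `SaveTorchScriptCallback` exports the trained mapping as a TorchScript module
+(`torchscript/drum_mapper.pt` under the run's log directory), which can be loaded for real-time inference.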
diff --git a/cfg/data/carson_pop_damp.yaml b/cfg/data/carson_pop_damp.yaml new file mode 100644 index 0000000..1bd8175 --- /dev/null +++ b/cfg/data/carson_pop_damp.yaml @@ -0,0 +1,6 @@ +class_path: timbreremap.data.OnsetFeatureDataModule +init_args: + audio_path: "audio/carson/popcorn_snare_damp.wav" + feature: ../features/onset_feature_full.yaml + onset_feature: ../features/onset_features.yaml + batch_size: 16 diff --git a/cfg/data/carson_pop_no_damp.yaml b/cfg/data/carson_pop_no_damp.yaml new file mode 100644 index 0000000..a8d2757 --- /dev/null +++ b/cfg/data/carson_pop_no_damp.yaml @@ -0,0 +1,6 @@ +class_path: timbreremap.data.OnsetFeatureDataModule +init_args: + audio_path: "audio/carson/popcorn_snare_no_damp.wav" + feature: ../features/onset_feature_full.yaml + onset_feature: ../features/onset_features.yaml + batch_size: 16 diff --git a/cfg/data/carson_supra_damp.yaml b/cfg/data/carson_supra_damp.yaml new file mode 100644 index 0000000..a4de4d3 --- /dev/null +++ b/cfg/data/carson_supra_damp.yaml @@ -0,0 +1,6 @@ +class_path: timbreremap.data.OnsetFeatureDataModule +init_args: + audio_path: "audio/carson/supra_snare_damp.wav" + feature: ../features/onset_feature_full.yaml + onset_feature: ../features/onset_features.yaml + batch_size: 16 diff --git a/cfg/data/carson_supra_no_damp.yaml b/cfg/data/carson_supra_no_damp.yaml new file mode 100644 index 0000000..6675d67 --- /dev/null +++ b/cfg/data/carson_supra_no_damp.yaml @@ -0,0 +1,6 @@ +class_path: timbreremap.data.OnsetFeatureDataModule +init_args: + audio_path: "audio/carson/supra_snare_no_damp.wav" + feature: ../features/onset_feature_full.yaml + onset_feature: ../features/onset_features.yaml + batch_size: 16 diff --git a/cfg/data/snare_onset.yaml b/cfg/data/snare_onset.yaml new file mode 100644 index 0000000..518c181 --- /dev/null +++ b/cfg/data/snare_onset.yaml @@ -0,0 +1,10 @@ +class_path: timbreremap.data.OnsetFeatureDataModule +init_args: + audio_path: "audio/sdss_filtered/Premier_BigFatSnare" + feature: ../features/onset_feature_full.yaml + onset_feature: ../features/onset_features.yaml + batch_size: 16 + center_onset: true + val_split: 0.1 + test_split: 0.1 + return_norm: false diff --git a/cfg/features/onset_feature_full.yaml b/cfg/features/onset_feature_full.yaml new file mode 100644 index 0000000..490e193 --- /dev/null +++ b/cfg/features/onset_feature_full.yaml @@ -0,0 +1,33 @@ +class_path: timbreremap.feature.FeatureCollection +init_args: + features: + - class_path: timbreremap.feature.CascadingFrameExtactor + init_args: + extractors: + - class_path: timbreremap.feature.Loudness + init_args: + sample_rate: 44100 + - class_path: timbreremap.feature.SpectralCentroid + init_args: + sample_rate: 44100 + window: "flat_top" + compress: true + floor: 1e-4 + scaling: "kazazis" + - class_path: timbreremap.feature.SpectralFlatness + num_frames: + - 2 + - 64 + frame_size: 2048 + hop_size: 512 + - class_path: timbreremap.feature.CascadingFrameExtactor + init_args: + extractors: + - class_path: timbreremap.feature.TemporalCentroid + init_args: + sample_rate: 44100 + scaling: "schlauch" + num_frames: + - 1 + frame_size: 5512 + hop_size: 5512 diff --git a/cfg/features/onset_features.yaml b/cfg/features/onset_features.yaml new file mode 100644 index 0000000..7509a5f --- /dev/null +++ b/cfg/features/onset_features.yaml @@ -0,0 +1,14 @@ +class_path: timbreremap.feature.CascadingFrameExtactor +init_args: + extractors: + - class_path: timbreremap.feature.RMS + init_args: + db: True + - class_path: timbreremap.feature.SpectralCentroid + 
init_args: + sample_rate: 44100 + - class_path: timbreremap.feature.SpectralFlatness + num_frames: + - 1 + frame_size: 256 + hop_size: 256 diff --git a/cfg/models/linear_mapper.yaml b/cfg/models/linear_mapper.yaml new file mode 100644 index 0000000..26f2df2 --- /dev/null +++ b/cfg/models/linear_mapper.yaml @@ -0,0 +1,8 @@ +class_path: timbreremap.model.LinearMapping +init_args: + in_size: 3 + out_size: 14 + bias: false + input_bias: 0.0 + clamp: true + init_std: 1e-6 diff --git a/cfg/models/mlp_mapper.yaml b/cfg/models/mlp_mapper.yaml new file mode 100644 index 0000000..9fe223b --- /dev/null +++ b/cfg/models/mlp_mapper.yaml @@ -0,0 +1,11 @@ +class_path: timbreremap.model.MLP +init_args: + in_size: 3 + hidden_size: 32 + out_size: 14 + num_layers: 1 + activation: torch.nn.ReLU + input_bias: 0.0 + layer_norm: true + init_std: 0.001 + scale_output: true diff --git a/cfg/models/mlp_mapper_lrg.yaml b/cfg/models/mlp_mapper_lrg.yaml new file mode 100644 index 0000000..c56251d --- /dev/null +++ b/cfg/models/mlp_mapper_lrg.yaml @@ -0,0 +1,11 @@ +class_path: timbreremap.model.MLP +init_args: + in_size: 3 + hidden_size: 64 + out_size: 14 + num_layers: 3 + activation: torch.nn.ReLU + input_bias: 0.0 + layer_norm: true + init_std: 0.001 + scale_output: true diff --git a/cfg/onset_mapping_808.yaml b/cfg/onset_mapping_808.yaml new file mode 100644 index 0000000..e3b996a --- /dev/null +++ b/cfg/onset_mapping_808.yaml @@ -0,0 +1,32 @@ +model: + class_path: timbreremap.tasks.TimbreRemappingTask + init_args: + model: models/mlp_mapper.yaml + synth: synths/snare_808.yaml + feature: features/onset_feature_full.yaml + loss_fn: timbreremap.loss.FeatureDifferenceLoss + preset: cfg/presets/808_snare_1.json +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 0.0005 +lr_scheduler: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + monitor: val/loss + factor: 0.5 + patience: 20 + verbose: true +data: data/snare_onset.yaml +trainer: + devices: 1 + accelerator: gpu + # accelerator: cpu + max_epochs: 250 + callbacks: + - class_path: timbreremap.callback.SaveAudioCallback + - class_path: timbreremap.callback.SaveTorchScriptCallback + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step +seed_everything: 3608013887 diff --git a/cfg/onset_mapping_808_linear.yaml b/cfg/onset_mapping_808_linear.yaml new file mode 100644 index 0000000..d986fbe --- /dev/null +++ b/cfg/onset_mapping_808_linear.yaml @@ -0,0 +1,32 @@ +model: + class_path: timbreremap.tasks.TimbreRemappingTask + init_args: + model: models/linear_mapper.yaml + synth: synths/snare_808.yaml + feature: features/onset_feature_full.yaml + loss_fn: timbreremap.loss.FeatureDifferenceLoss + preset: cfg/presets/808_snare_1.json +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 0.005 +lr_scheduler: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + monitor: val/loss + factor: 0.5 + patience: 20 + verbose: true +data: data/snare_onset.yaml +trainer: + devices: 1 + accelerator: gpu + # accelerator: cpu + max_epochs: 250 + callbacks: + - class_path: timbreremap.callback.SaveAudioCallback + - class_path: timbreremap.callback.SaveTorchScriptCallback + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step +seed_everything: 3608013887 diff --git a/cfg/onset_mapping_808_lrg.yaml b/cfg/onset_mapping_808_lrg.yaml new file mode 100644 index 0000000..c9c2699 --- /dev/null +++ b/cfg/onset_mapping_808_lrg.yaml @@ -0,0 +1,32 @@ +model: + 
class_path: timbreremap.tasks.TimbreRemappingTask + init_args: + model: models/mlp_mapper_lrg.yaml + synth: synths/snare_808.yaml + feature: features/onset_feature_full.yaml + loss_fn: timbreremap.loss.FeatureDifferenceLoss + preset: cfg/presets/808_snare_1.json +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 0.0005 +lr_scheduler: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + monitor: val/loss + factor: 0.5 + patience: 20 + verbose: true +data: data/snare_onset.yaml +trainer: + devices: 1 + accelerator: gpu + # accelerator: cpu + max_epochs: 250 + callbacks: + - class_path: timbreremap.callback.SaveAudioCallback + - class_path: timbreremap.callback.SaveTorchScriptCallback + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step +seed_everything: 3608013887 diff --git a/cfg/presets/808_noisy_snare.json b/cfg/presets/808_noisy_snare.json new file mode 100644 index 0000000..106fe44 --- /dev/null +++ b/cfg/presets/808_noisy_snare.json @@ -0,0 +1,22 @@ +{ + "preset": { + "osc1.freq": 500.0, + "osc1.mod": 0.42, + "osc2.freq": 700.0, + "osc2.mod": 1.0, + "freq_env.decay": 787.0, + "osc1_env.decay": 100.0, + "osc2_env.decay": 200.0, + "noise_env.decay": 1949.0, + "noise_filter.freq": 3915.0, + "noise_filter.q": 6.0, + "osc1_gain.gain": -50.0, + "osc2_gain.gain": -50.0, + "noise_gain.gain": -25.0, + "tanh.in_gain": 13.0 + }, + "damping": { + "osc1.freq": 0.01, + "osc2.freq": 0.01 + } +} diff --git a/cfg/presets/808_open_snare.json b/cfg/presets/808_open_snare.json new file mode 100644 index 0000000..3b185d3 --- /dev/null +++ b/cfg/presets/808_open_snare.json @@ -0,0 +1,22 @@ +{ + "preset": { + "osc1.freq": 135.0, + "osc1.mod": 0.43, + "osc2.freq": 236.0, + "osc2.mod": 0.53, + "freq_env.decay": 119.0, + "osc1_env.decay": 558.0, + "osc2_env.decay": 43.0, + "noise_env.decay": 100.0, + "noise_filter.freq": 10000.0, + "noise_filter.q": 1.2, + "osc1_gain.gain": -10.0, + "osc2_gain.gain": -14.0, + "noise_gain.gain": -53.0, + "tanh.in_gain": -3.0 + }, + "damping": { + "osc1.freq": 0.01, + "osc2.freq": 0.01 + } +} diff --git a/cfg/presets/808_snare_1.json b/cfg/presets/808_snare_1.json new file mode 100644 index 0000000..c514759 --- /dev/null +++ b/cfg/presets/808_snare_1.json @@ -0,0 +1,22 @@ +{ + "preset": { + "osc1.freq": 180.0, + "osc1.mod": 0.4, + "osc2.freq": 240.0, + "osc2.mod": 0.3, + "freq_env.decay": 50.0, + "osc1_env.decay": 200.0, + "osc2_env.decay": 150.0, + "noise_env.decay": 200.0, + "noise_filter.freq": 2500.0, + "noise_filter.q": 2.0, + "osc1_gain.gain": -6.0, + "osc2_gain.gain": -12.0, + "noise_gain.gain": -4.0, + "tanh.in_gain": -12.0 + }, + "damping": { + "osc1.freq": 0.01, + "osc2.freq": 0.01 + } +} diff --git a/cfg/presets/808_snare_2.json b/cfg/presets/808_snare_2.json new file mode 100644 index 0000000..35159f5 --- /dev/null +++ b/cfg/presets/808_snare_2.json @@ -0,0 +1,22 @@ +{ + "preset": { + "osc1.freq": 179.0, + "osc1.mod": 0.1, + "osc2.freq": 366.0, + "osc2.mod": 0.34, + "freq_env.decay": 21.59, + "osc1_env.decay": 133.0, + "osc2_env.decay": 106.5, + "noise_env.decay": 108.0, + "noise_filter.freq": 2295.0, + "noise_filter.q": 0.5, + "osc1_gain.gain": -18.0, + "osc2_gain.gain": -20.0, + "noise_gain.gain": -10.0, + "tanh.in_gain": 1.0 + }, + "damping": { + "osc1.freq": 0.01, + "osc2.freq": 0.01 + } +} diff --git a/cfg/presets/808_snare_3.json b/cfg/presets/808_snare_3.json new file mode 100644 index 0000000..62f362d --- /dev/null +++ b/cfg/presets/808_snare_3.json @@ -0,0 +1,22 @@ +{ + "preset": { + 
"osc1.freq": 173.0, + "osc1.mod": 0.66, + "osc2.freq": 351.0, + "osc2.mod": 0.24, + "freq_env.decay": 33.19, + "osc1_env.decay": 152.07, + "osc2_env.decay": 91.15, + "noise_env.decay": 269.0, + "noise_filter.freq": 1500.0, + "noise_filter.q": 1.0, + "osc1_gain.gain": -14.0, + "osc2_gain.gain": -13.0, + "noise_gain.gain": -44.0, + "tanh.in_gain": -3.5 + }, + "damping": { + "osc1.freq": 0.01, + "osc2.freq": 0.01 + } +} diff --git a/cfg/presets/808_tom_1.json b/cfg/presets/808_tom_1.json new file mode 100644 index 0000000..3b185d3 --- /dev/null +++ b/cfg/presets/808_tom_1.json @@ -0,0 +1,22 @@ +{ + "preset": { + "osc1.freq": 135.0, + "osc1.mod": 0.43, + "osc2.freq": 236.0, + "osc2.mod": 0.53, + "freq_env.decay": 119.0, + "osc1_env.decay": 558.0, + "osc2_env.decay": 43.0, + "noise_env.decay": 100.0, + "noise_filter.freq": 10000.0, + "noise_filter.q": 1.2, + "osc1_gain.gain": -10.0, + "osc2_gain.gain": -14.0, + "noise_gain.gain": -53.0, + "tanh.in_gain": -3.0 + }, + "damping": { + "osc1.freq": 0.01, + "osc2.freq": 0.01 + } +} diff --git a/cfg/synths/snare_808.yaml b/cfg/synths/snare_808.yaml new file mode 100644 index 0000000..d914e3e --- /dev/null +++ b/cfg/synths/snare_808.yaml @@ -0,0 +1,6 @@ +class_path: timbreremap.synth.Snare808 +init_args: + sample_rate: 44100 + num_samples: 44100 + buffer_noise: true + buffer_size: 44100 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9dbd5f4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,46 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["timbreremap"] + +[project] +name = "timbreremap" +version = "0.0.1" +dependencies = [ + "einops", + "jsonargparse[signatures]>=4.18.0", + "lightning==2.1.3", + "matplotlib", + "numba", + "numpy<2.0", + "pandas", + "pyloudnorm", + "scipy", + "torch==2.1.2", + "torchaudio", +] + +[project.scripts] +timbreremap = "timbreremap.cli:main" +timbreremap-test = "timbreremap.cli:test_version" +timbreremap-train-sdss = "timbreremap.cli:train_sdss" +timbreremap-optuna = "timbreremap.optuna:run_optuna" +timbreremap-direct = "timbreremap.cli:direct_optimization" +timbreremap-optimize-sdss = "timbreremap.cli:optimize_sdss" + +[project.optional-dependencies] +dev = [ + "black[jupyter]", + "flake8", + "pytest", + "pytest-mock", + "pre-commit", + "pytest-cov", + "matplotlib", + "nbstripout", + "nbmake", + "ipywidgets", + "optuna", +] diff --git a/scripts/direct_optimize.sh b/scripts/direct_optimize.sh new file mode 100755 index 0000000..6e1729e --- /dev/null +++ b/scripts/direct_optimize.sh @@ -0,0 +1,5 @@ +timbreremap-optimize-sdss -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_snare_1.json --trainer.default_root_dir experiment/test_logs_direct_opt +timbreremap-optimize-sdss -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_snare_2.json --trainer.default_root_dir experiment/test_logs_direct_opt +timbreremap-optimize-sdss -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_snare_3.json --trainer.default_root_dir experiment/test_logs_direct_opt +timbreremap-optimize-sdss -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_open_snare.json --trainer.default_root_dir experiment/test_logs_direct_opt +timbreremap-optimize-sdss -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_noisy_snare.json --trainer.default_root_dir experiment/test_logs_direct_opt diff --git a/scripts/results.py b/scripts/results.py new file mode 100644 index 0000000..8574870 --- /dev/null +++ 
b/scripts/results.py @@ -0,0 +1,196 @@ +""" +Compile numerical experiment results into a LaTeX table. +""" +import argparse +import os +import sys +from pathlib import Path + +import pandas as pd + +# Feature indices +IDX = [ + "SPL_T Mean", + "SPL_T Std", + "SPL_S Mean", + "SPL_S Std", + "SC_T Mean", + "SC_T Std", + "SC_S Mean", + "SC_S Std", + "SF_T Mean", + "SF_T Std", + "SF_S Mean", + "SF_S Std", + "TC Mean", + "TC Std", +] + + +def get_folder_metrics(dir): + # Initialize an empty list to store the dataframes + dfs = [] + + print(f"Processing directory: {dir} ... Found: {Path(dir).exists()}") + size = Path(dir).parent.name.split("_")[2] + # Recursively search through folders in the cwd + for root, dirs, files in os.walk(Path(dir)): + parts = Path(root).name.split("_") + if parts[0] == "version": + version = parts[1] + + for file in files: + # Check if the file is named "metrics.csv" and has a CSV extension + if file == "metrics.csv" and file.endswith(".csv"): + # Construct the file path + file_path = os.path.join(root, file) + + # Load the CSV file into a dataframe + df = pd.read_csv(file_path) + df["version"] = version + df["size"] = size + df["preset"] = None + + dfs.append(df) + + # Concatenate all the dataframes into a single dataframe + combined_df = pd.concat(dfs, ignore_index=True) + return combined_df + + +def update_columns(df): + new_columns = [] + for c in df.columns: + if "Loudness" in c[1]: + feature = "SPL" + elif "SpectralCentroid" in c[1]: + feature = "SC" + elif "SpectralFlatness" in c[1]: + feature = "SF" + + if "0_" in c[1]: + feature += "_T" + else: + feature += "_S" + + if "TemporalCentroid" in c[1]: + feature = "TC" + + if "mean" in c[0]: + feature += " Mean" + else: + feature += " Std" + + new_columns.append(feature) + df.columns = new_columns + return df + + +def convert_to_text(df): + feature = "" + f_dict = {} + c_names = df.columns + for r in df.iterrows(): + if r[0].endswith("Mean"): + feature = r[0].split(" ")[0] + values = r[1].values + elif r[0].endswith("Std"): + v = [] + for m, s in zip(values, r[1].values): + if s >= 10: + std = f"{s:.0f}" + else: + std = f"{s:.1f}" + + if m >= 10: + v.append(f"${m:.1f} \pm {std}$") # noqa W605 + else: + v.append(f"${m:.3f} \pm {std}$") # noqa W605 + + feature = f"${feature}$" + f_dict[feature] = v + + df1 = pd.DataFrame(f_dict).T + df1.columns = c_names + df1 = df1[ + [ + "preset", + "direct", + "linear", + "linear2048", + "mlp", + "mlp2048", + "mlplrg", + "mlplrg2048", + ] + ] + return df1 + + +def main(arguments): + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("input", help="Input directory", type=str) + args = parser.parse_args(arguments) + + # Load in all the log files and combine them into a single dataframe + root = args.input + df_linear = get_folder_metrics(f"{root}/test_logs_linear/lightning_logs") + df_linear_2048 = get_folder_metrics(f"{root}/test_logs_linear2048/lightning_logs") + + df_mlp = get_folder_metrics(f"{root}/test_logs_mlp/lightning_logs") + df_mlp_2048 = get_folder_metrics(f"{root}/test_logs_mlp2048/lightning_logs") + + df_mlp_lrg = get_folder_metrics(f"{root}/test_logs_mlplrg/lightning_logs") + df_mlp_lrg_2048 = get_folder_metrics(f"{root}/test_logs_mlplrg2048/lightning_logs") + + df_direct = get_folder_metrics(f"{root}/test_logs_direct_opt/lightning_logs") + combined_df = pd.concat( + [ + df_linear, + df_linear_2048, + df_mlp, + df_mlp_2048, + df_direct, + df_mlp_lrg, + df_mlp_lrg_2048, + ], + 
ignore_index=True, + ) + + # Convert to a pivot table for easier processing + columns = [ + c + for c in combined_df.columns + if c not in ["epoch", "step", "test/loss", "version", "size", "preset"] + ] + columns = [c for c in columns if "pre" not in c] + df = pd.pivot_table( + combined_df, values=columns, index=["size"], aggfunc=["mean", "std"] + ) + df = update_columns(df) + df = df.T + + # Repeat for the preset + columns = [ + c + for c in combined_df.columns + if c not in ["epoch", "step", "test/loss", "version", "size", "preset"] + ] + columns = [c for c in columns if "pre" in c] + df_pre = pd.pivot_table( + combined_df, values=columns, index=["size"], aggfunc=["mean", "std"] + ) + df_pre = update_columns(df_pre) + df["preset"] = df_pre.T["linear"].values + + df = df.reindex(IDX) + df = convert_to_text(df) + + # Save the dataframe to a LaTeX table + df.to_latex("table.tex", escape=False) + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/scripts/train_linear.sh b/scripts/train_linear.sh new file mode 100755 index 0000000..9383f51 --- /dev/null +++ b/scripts/train_linear.sh @@ -0,0 +1,14 @@ +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_snare_1.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_linear2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_snare_2.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_linear2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_snare_3.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_linear2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_open_snare.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_linear2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_noisy_snare.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_linear2048_train + +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_snare_1.json --trainer.default_root_dir experiment/logs_linear_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_snare_2.json --trainer.default_root_dir experiment/logs_linear_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_snare_3.json --trainer.default_root_dir experiment/logs_linear_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_open_snare.json --trainer.default_root_dir experiment/logs_linear_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_linear.yaml --model.preset cfg/presets/808_noisy_snare.json --trainer.default_root_dir experiment/logs_linear_train + +timbreremap-test experiment/logs_linear2048_train/lightning_logs experiment/test_logs_linear2048 +timbreremap-test experiment/logs_linear_train/lightning_logs experiment/test_logs_linear diff --git a/scripts/train_mlp.sh b/scripts/train_mlp.sh new file mode 100755 index 0000000..5a59e06 --- /dev/null +++ b/scripts/train_mlp.sh @@ -0,0 +1,14 @@ +timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_snare_1.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlp2048_train 
+timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_snare_2.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlp2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_snare_3.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlp2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_open_snare.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlp2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_noisy_snare.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlp2048_train + +timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_snare_1.json --trainer.default_root_dir experiment/logs_mlp_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_snare_2.json --trainer.default_root_dir experiment/logs_mlp_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_snare_3.json --trainer.default_root_dir experiment/logs_mlp_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_open_snare.json --trainer.default_root_dir experiment/logs_mlp_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808.yaml --model.preset cfg/presets/808_noisy_snare.json --trainer.default_root_dir experiment/logs_mlp_train + +timbreremap-test experiment/logs_mlp2048_train/lightning_logs experiment/test_logs_mlp2048 +timbreremap-test experiment/logs_mlp_train/lightning_logs experiment/test_logs_mlp diff --git a/scripts/train_mlp_lrg.sh b/scripts/train_mlp_lrg.sh new file mode 100755 index 0000000..2749133 --- /dev/null +++ b/scripts/train_mlp_lrg.sh @@ -0,0 +1,14 @@ +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_snare_1.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlplrg2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_snare_2.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlplrg2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_snare_3.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlplrg2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_open_snare.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlplrg2048_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_noisy_snare.json --data.onset_feature.frame_size 2048 --trainer.default_root_dir experiment/logs_mlplrg2048_train + +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_snare_1.json --trainer.default_root_dir experiment/logs_mlplrg_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_snare_2.json --trainer.default_root_dir experiment/logs_mlplrg_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_snare_3.json --trainer.default_root_dir experiment/logs_mlplrg_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_open_snare.json 
--trainer.default_root_dir experiment/logs_mlplrg_train +timbreremap-train-sdss fit -c cfg/onset_mapping_808_lrg.yaml --model.preset cfg/presets/808_noisy_snare.json --trainer.default_root_dir experiment/logs_mlplrg_train + +timbreremap-test experiment/logs_mlplrg2048_train/lightning_logs experiment/test_logs_mlplrg2048 +timbreremap-test experiment/logs_mlplrg_train/lightning_logs experiment/test_logs_mlplrg diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..81558c6 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,8 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203 +per-file-ignores = __init__.py:F401 + +[tool:pytest] +testpaths = test +addopts = --cov=timbreremap --cov-report term-missing diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..685d12a --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,3 @@ +import os + +os.environ["NUMBA_DISABLE_JIT"] = "1" diff --git a/test/test_data.py b/test/test_data.py new file mode 100644 index 0000000..eb428b3 --- /dev/null +++ b/test/test_data.py @@ -0,0 +1,62 @@ +import pytest +import torch +import torchaudio + +from timbreremap.data import OnsetFeatureDataModule + + +def test_onset_feature_datamodule_init(tmp_path): + audio_path = tmp_path / "test.wav" + feature = torch.nn.Linear(44100, 3) + onset_feature = torch.nn.Linear(44100, 1) + data_module = OnsetFeatureDataModule(audio_path, feature, onset_feature) + assert data_module.audio_path == audio_path + assert data_module.feature == feature + + +@pytest.fixture +def audio_folder(tmp_path): + """Create a folder with 8 noisy audio files""" + for i in range(8): + audio_path = tmp_path / f"{i}.wav" + test_audio = torch.rand(1, 44100) * 2.0 - 1.0 + amp_env = torch.linspace(1, 0, 44100) + test_audio = test_audio * amp_env + torchaudio.save(audio_path, test_audio, 44100) + + yield tmp_path + + # Clean up + for p in tmp_path.glob("*.wav"): + p.unlink() + + +def test_onset_feature_datamodule_prepare(audio_folder): + feature = torch.nn.Linear(44100, 6) + onset_feature = torch.nn.Linear(44100, 3) + data_module = OnsetFeatureDataModule( + audio_folder, feature, onset_feature, sample_rate=44100 + ) + data_module.prepare_data() + + assert hasattr(data_module, "full_features") + assert hasattr(data_module, "onset_features") + assert data_module.full_features.shape == (8, 6) + assert data_module.onset_features.shape == (8, 3) + + +def test_onset_feature_datamodule_setup(audio_folder): + feature = torch.nn.Linear(44100, 6) + onset_feature = torch.nn.Linear(44100, 3) + data_module = OnsetFeatureDataModule( + audio_folder, feature, onset_feature, sample_rate=44100, return_norm=True + ) + data_module.prepare_data() + data_module.setup("fit") + + assert hasattr(data_module, "train_dataset") + assert len(data_module.train_dataset) == 8 + o, f, w = data_module.train_dataset[0] + assert o.shape == (3,) + assert f.shape == (6,) + assert w.shape == f.shape diff --git a/test/test_feature.py b/test/test_feature.py new file mode 100644 index 0000000..e142981 --- /dev/null +++ b/test/test_feature.py @@ -0,0 +1,84 @@ +import numpy as np +import torch + +from timbreremap.feature import FeatureExtractor +from timbreremap.feature import OnsetSegment +from timbreremap.feature import RMS +from timbreremap.feature import SpectralCentroid +from timbreremap.feature import SpectralFlatness + + +def test_onset_segment(): + x = torch.arange(100).unsqueeze(0).float() + seg = OnsetSegment(50) + y = 
seg(x) + assert y.shape == (1, 50) + assert torch.all(y == torch.arange(50).unsqueeze(0).float()) + + +def test_onset_segment_delay(): + x = torch.arange(100).unsqueeze(0).float() + seg = OnsetSegment(50, 10) + y = seg(x) + assert y.shape == (1, 50) + assert torch.all(y == torch.arange(10, 60).unsqueeze(0).float()) + + +def test_rms(): + x = torch.ones(100).unsqueeze(0).float() + rms = RMS() + y = rms(x) + assert y.shape == (1,) + assert y.item() == 1.0 + + +def test_rms_db(): + x = torch.ones(100).unsqueeze(0).float() + rms = RMS(db=True) + y = rms(x) + assert y.shape == (1,) + assert y.item() == 0.0 + + +def test_spectral_centroid(): + sr = 44100 + f0 = sr / 128.0 + w0 = 2.0 * torch.pi * (f0 / sr) + phase = torch.cumsum(w0 * torch.ones(128), dim=-1) + x = torch.sin(phase).unsqueeze(0).float() + + # Test without windowing -- should be exact + sc = SpectralCentroid(sr, scaling="none", window="none") + y = sc(x) + assert y.shape == (1,) + torch.testing.assert_close(y, torch.tensor([f0]), atol=1e-4, rtol=1e-4) + + +def test_spectral_flatness(): + sr = 44100 + w0 = 2.0 * torch.pi * (440.0 / sr) + phase = torch.cumsum(w0 * torch.ones(128), dim=-1) + x = torch.sin(phase).unsqueeze(0).float() + + sf = SpectralFlatness(window="hann") + y = sf(x) + assert torch.all(y < -150.0) + + g = torch.Generator() + g.manual_seed(0) + noise = torch.rand(1, sr, generator=g) * 2.0 - 1.0 + y = sf(noise) + assert torch.all(y > -6.0) + + silence = torch.zeros(1, 1024) + y = sf(silence) + torch.testing.assert_close(y, torch.zeros_like(y)) + + +def test_feature_extractor(): + x = torch.zeros(100).unsqueeze(0) + x[:, 50:] = 1.0 + features = [OnsetSegment(50), RMS()] + extractor = FeatureExtractor(features) + y = extractor(x) + torch.testing.assert_close(y, torch.tensor([np.sqrt(1e-8)], dtype=torch.float)) diff --git a/test/test_model.py b/test/test_model.py new file mode 100644 index 0000000..f9506d6 --- /dev/null +++ b/test/test_model.py @@ -0,0 +1,30 @@ +""" +Unit tests for timbre remapping models +""" +import torch + +from timbreremap.model import LinearMapping + + +def test_linear_mapping_init(): + model = LinearMapping(3, 8) + assert isinstance(model, LinearMapping) + assert len(list(model.children())) == 1 + assert model.net.bias is None + assert model.net.weight.shape == (8, 3) + + +def test_linear_mapping_forward(): + model = LinearMapping(2, 3) + torch.nn.init.constant_(model.net.weight, 1.0) + y = model(torch.tensor([0.5, 1.0])) + assert torch.all(y == 1.5) + + +def test_linear_mapping_clamp(): + model = LinearMapping(2, 3, clamp=True) + torch.nn.init.constant_(model.net.weight, 0.5) + y = model(torch.tensor([2.0, 4.0])) + assert torch.all(y == 1.0) + y = model(torch.tensor([2.0, -4.0])) + assert torch.all(y == -1.0) diff --git a/test/test_np_core.py b/test/test_np_core.py new file mode 100644 index 0000000..e5a6d7f --- /dev/null +++ b/test/test_np_core.py @@ -0,0 +1,51 @@ +import numpy as np + +from timbreremap.np.core import envelope_follower +from timbreremap.np.core import EnvelopeFollower +from timbreremap.np.core import HighPassFilter + + +def test_construct_hpf(): + filter = HighPassFilter(sr=44100, cutoff=1000) + assert filter.sr == 44100 + assert filter.cutoff == 1000 + + +def test_hpf_call(): + x = np.random.randn(1, 44100) + filter = HighPassFilter(sr=44100, cutoff=10000) + y = filter(x) + assert y.shape == x.shape + + +def test_envelope_follower(): + # Test envelope follower rising up + x = np.zeros((1, 1000)) + x[0, 500:] = 1.0 + y = envelope_follower(x, up=1.0 / 100.0, down=1.0) + assert 
y.shape == x.shape
+    assert np.all(y[:, :500] == 0.0)
+    assert np.all(y[:, 500:600] > 0.0) and np.all(y[:, 500:599] < (1.0 - 1.0 / np.e))
+    assert np.all(y[:, 600:] >= (1.0 - 1.0 / np.e))
+
+    # Test envelope follower falling down
+    x = np.zeros((1, 1000))
+    x[0, :500] = 1.0
+    y = envelope_follower(x, up=1.0, down=1.0 / 100.0)
+    assert y.shape == x.shape
+    assert np.all(y[:, :500] == 1.0)
+    assert np.all(y[:, 500:600] < 1.0) and np.all(y[:, 500:599] > 1.0 / np.e)
+    assert np.all(y[:, 600:] <= 1.0 / np.e)
+
+
+def test_envelope_follower_class(mocker):
+    follower = EnvelopeFollower(attack_samples=100, release_samples=100)
+    assert follower.up == 0.01
+    assert follower.down == 0.01
+
+    x = np.zeros((1, 1000))
+    x[0, 500:600] = 1.0
+
+    mocked_ef = mocker.patch("timbreremap.np.core.envelope_follower")
+    follower(x)
+    mocked_ef.assert_called_once_with(x, 0.01, 0.01, initial=0.0)
diff --git a/test/test_synth.py b/test/test_synth.py
new file mode 100644
index 0000000..c204096
--- /dev/null
+++ b/test/test_synth.py
@@ -0,0 +1,45 @@
+import torch
+
+from timbreremap.synth import CrossFade
+from timbreremap.synth import ExponentialDecay
+from timbreremap.synth import ParamaterNormalizer
+
+
+def test_parameter_normalizer_init():
+    normalizer = ParamaterNormalizer(0.0, 1.0)
+    assert normalizer.min_value == 0.0
+    assert normalizer.max_value == 1.0
+
+
+def test_parameter_normalizer_from_0to1():
+    normalizer = ParamaterNormalizer(10.0, 20.0)
+    x = torch.linspace(0.0, 1.0, 10)
+    y = normalizer.from_0to1(x)
+    expected = torch.linspace(10.0, 20.0, 10)
+    assert torch.allclose(y, expected)
+
+
+def test_parameter_normalizer_to_0to1():
+    normalizer = ParamaterNormalizer(10.0, 20.0)
+    x = torch.linspace(10.0, 20.0, 10)
+    y = normalizer.to_0to1(x)
+    expected = torch.linspace(0.0, 1.0, 10)
+    assert torch.allclose(y, expected)
+
+
+def test_exponential_decay_init():
+    decay = ExponentialDecay(sample_rate=44100)
+    assert decay.normalizers["decay"].min_value == 10.0
+
+
+def test_exponential_decay_forward():
+    env = ExponentialDecay(sample_rate=44100)
+    decay = torch.tensor([0.1, 0.5])[..., None]
+    y = env(1000, decay)
+    assert y.shape == (2, 1000)
+
+
+def test_crossfade_init():
+    fade = CrossFade(sample_rate=44100)
+    assert fade.normalizers["fade"].min_value == 0.0
+    assert fade.normalizers["fade"].max_value == 1.0
diff --git a/test/test_tasks.py b/test/test_tasks.py
new file mode 100644
index 0000000..a4a49f9
--- /dev/null
+++ b/test/test_tasks.py
@@ -0,0 +1,80 @@
+import pytest
+import torch
+
+from timbreremap.loss import FeatureDifferenceLoss
+from timbreremap.model import MLP
+from timbreremap.synth import SimpleDrumSynth
+from timbreremap.tasks import TimbreRemappingTask
+
+
+@pytest.fixture
+def features():
+    # Dummy feature extractor
+    return torch.nn.Linear(44100, 3)
+
+
+@pytest.fixture
+def preset_json(tmp_path):
+    # Create a dummy preset JSON
+    synth = SimpleDrumSynth(sample_rate=44100, num_samples=44100)
+    num_params = synth.get_num_params()
+    preset = torch.rand(1, num_params)
+    preset_dict = synth.get_param_dict()
+    for i, (k, n) in enumerate(synth.get_param_dict().items()):
+        preset_dict[k] = n.from_0to1(preset[0, i]).item()
+    preset_json = tmp_path.joinpath("preset.json")
+    synth.save_params_json(preset_json, preset_dict)
+    return preset_json
+
+
+def test_init_timbre_remapping_task(features, preset_json):
+    synth = SimpleDrumSynth(sample_rate=44100, num_samples=44100)
+    model = MLP(in_size=1, hidden_size=1, out_size=1, num_layers=1)
+    _ = TimbreRemappingTask(
+        model=model,
+        synth=synth,
preset=preset_json, + feature=features, + loss_fn=None, + ) + + +def test_timbre_remapping_task_forward(features, preset_json): + synth = SimpleDrumSynth(sample_rate=44100, num_samples=44100) + num_params = synth.get_num_params() + model = MLP(in_size=1, hidden_size=1, out_size=num_params, num_layers=1) + mapping = TimbreRemappingTask( + model=model, + synth=synth, + preset=preset_json, + feature=features, + loss_fn=None, + ) + + inputs = torch.rand(1, 1) + y = mapping(inputs) + assert y.shape == (1, 44100) + + +def test_timbre_remapping_task_train_step(features, preset_json): + # Initialize a synth and preset + synth = SimpleDrumSynth(sample_rate=44100, num_samples=44100) + num_params = synth.get_num_params() + + # Initialize a parameter mapping model + mapping = TimbreRemappingTask( + model=torch.nn.Linear(1, num_params), + synth=synth, + preset=preset_json, + feature=features, + loss_fn=FeatureDifferenceLoss(), + ) + + # Create a dummy batch of inputs + inputs = torch.rand(1, 1) + target = torch.rand(1, 1) + loss = mapping.training_step((inputs, target), 0) + + # Check that the loss is a scalar + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 diff --git a/timbreremap/__init__.py b/timbreremap/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/timbreremap/callback.py b/timbreremap/callback.py new file mode 100644 index 0000000..bcb852b --- /dev/null +++ b/timbreremap/callback.py @@ -0,0 +1,162 @@ +""" +PyTorch Lightning Callbacks +""" +from pathlib import Path + +import lightning as L +import matplotlib.pyplot as plt +import pandas as pd +import torch +import torchaudio +from einops import rearrange +from einops import repeat + +from timbreremap.data import OnsetFeatureDataset +from timbreremap.export import ParameterMapper + + +class SaveAudioCallback(L.Callback): + def __init__(self, num_samples: int = 16): + super().__init__() + self.num_samples = num_samples + + def on_train_end(self, trainer: L.Trainer, module: L.LightningModule) -> None: + data = trainer.train_dataloader.dataset + + # If the onset reference values are being used in an OnsetFeatureDataset, + # pass those values in to normalize the input features + onset_ref = None + if isinstance(data, OnsetFeatureDataset): + onset_ref = data.onset_ref + + self.render_audio(module, onset_ref) + self.render_fig(module, onset_ref) + + def render_audio(self, module, onset_ref=None): + outdir = Path(module.logger.log_dir).joinpath("audio") + outdir.mkdir(exist_ok=True) + + num_features = module.model.in_size + + # Generate input ranging from 0 to 1 + x = torch.linspace(0, 1, self.num_samples, device=module.device) + x = repeat(x, "n -> n f", f=num_features) + + # Offset onset features + if onset_ref is not None: + x = x - onset_ref.to(module.device) + + # Generate audio + y = module(x) + y = rearrange(y, "b n -> 1 (b n)") + y = y.detach().cpu() + + torchaudio.save( + outdir.joinpath("gradient_all.wav"), y, module.synth.sample_rate + ) + + def render_fig(self, module, onset_ref=None): + # Save a plot of parameter changes + outdir = Path(module.logger.log_dir).joinpath("plots") + outdir.mkdir(exist_ok=True) + + # Generate input ranging from 0 to 1 + x = torch.linspace(0, 1, self.num_samples * 100, device=module.device) + x = repeat(x, "n -> n f", f=module.model.in_size) + + # Offset onset features + if onset_ref is not None: + print(onset_ref) + x = x - onset_ref.to(module.device) + + labels = list(module.synth.get_param_dict().keys()) + + fig, ax = plt.subplots(1, 1, figsize=(7, 5)) + preset = 
module.preset.detach().cpu().numpy() + param_mod = module.model(x).detach().cpu().numpy() + damping = module.damping.detach().cpu().numpy() + + for i in range(preset.shape[-1]): + mod = param_mod[:, i] * damping[0, i] + ax.plot(mod, label="-".join(labels[i])) + + fig.legend(loc=7) + fig.subplots_adjust(right=0.75) + fig.savefig(outdir.joinpath("parameter_modulationss.png"), dpi=150) + + +class SaveTorchScriptCallback(L.Callback): + def __init__(self): + super().__init__() + + def on_train_end(self, trainer: L.Trainer, module: L.LightningModule) -> None: + model = module.model.cpu() + damping = module.damping.detach().cpu() + preset = module.preset.detach().cpu() + mapper = ParameterMapper(model, damping, preset) + sm = torch.jit.script(mapper) + + # Test the torchscript module + x = torch.rand(1, module.model.in_size) + y = sm(x) + + num_params = module.synth.get_num_params() + assert y.shape == (2, num_params) + + # Save the torchscript module + outdir = Path(module.logger.log_dir).joinpath("torchscript") + outdir.mkdir(exist_ok=True) + torch.jit.save(sm, outdir.joinpath("drum_mapper.pt")) + + +class SaveTimbreIntervalAudio(L.Callback): + def __init__(self): + super().__init__() + + def on_train_end(self, trainer: L.Trainer, module: L.LightningModule) -> None: + outdir = Path(module.logger.log_dir).joinpath("audio") + outdir.mkdir(exist_ok=True) + + # Save the reference and target audio + data = trainer.train_dataloader.dataset + torchaudio.save( + outdir.joinpath("A.wav"), data.reference_audio, data.sample_rate + ) + torchaudio.save(outdir.joinpath("B.wav"), data.target_audio, data.sample_rate) + + # Generate audio for the timbre intervals + y_true = module.synth(module.preset).detach().cpu() + y_pred = module().detach().cpu() + + torchaudio.save(outdir.joinpath("C.wav"), y_true, module.synth.sample_rate) + torchaudio.save(outdir.joinpath("D.wav"), y_pred, module.synth.sample_rate) + + # What is the error per feature? 
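+        # The realized interval (features of D relative to C above) is
+        # compared against the target interval stored in the dataset
+        # (B relative to A).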
+ ref_true = module.feature(y_true) + ref_pred = module.feature(y_pred) + + # Calculate feature differences + diff = ref_pred - ref_true + error = diff - data[0] + error = error[0].tolist() + + meta = [] + for feature in module.feature.features: + frame_size = feature.frame_size + hop_size = feature.hop_size + for f in feature.flattened_features: + meta.append( + { + "feature": f"{f[0]}.{f[1]}", + "frame_size": frame_size, + "hop_size": hop_size, + } + ) + + feature_error = [] + for e, f in zip(error, meta): + feature_error.append({**f, "error": e}) + + df = pd.DataFrame(feature_error) + outdir = Path(module.logger.log_dir) + df.to_csv(outdir.joinpath("feature_error.csv"), index=False) diff --git a/timbreremap/cli.py b/timbreremap/cli.py new file mode 100644 index 0000000..e769c98 --- /dev/null +++ b/timbreremap/cli.py @@ -0,0 +1,208 @@ +""" +timbre remapping cli entry point +""" +import argparse +import copy +import logging +import os +import sys +import time +from pathlib import Path + +import lightning as L +import pandas as pd +import torch +from lightning.pytorch.cli import LightningCLI +from tqdm import tqdm + + +# Setup logging +logging.basicConfig() +log = logging.getLogger(__name__) +log.setLevel(level=os.environ.get("LOGLEVEL", "INFO")) + + +def run_cli(): + """ """ + _ = LightningCLI() + return + + +def main(): + """ """ + start_time = time.time() + run_cli() + end_time = time.time() + log.info(f"Total time: {end_time - start_time} seconds") + + +def test_version(): + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("train_logs", help="Input log directory", type=str) + parser.add_argument("output_dir", help="Output log directory", type=str) + + # Get input directory + args = parser.parse_args(sys.argv[1:]) + logs = Path(args.train_logs) + output = Path(args.output_dir) + + # Reset arguments + args = sys.argv[:1] + configs = [] + for version in sorted(logs.iterdir()): + if version.is_dir(): + config = version.joinpath("config.yaml") + ckpt = list(version.joinpath("checkpoints").glob("*.ckpt")) + if len(ckpt) > 0: + ckpt = ckpt[0] + else: + continue + + if config.exists(): + cfg_args = ["test", "-c", str(config), "--ckpt", str(ckpt)] + configs.append(cfg_args) + + for c in configs: + run_args = copy.deepcopy(args) + run_args.extend(c) + run_args.extend(["--trainer.logger", "CSVLogger"]) + run_args.extend(["--trainer.logger.save_dir", str(output)]) + sys.argv = run_args + print(run_args) + _ = LightningCLI() + + return + + +def train_sdss(): + """ """ + args = sys.argv + configs = [] + folders = Path("audio/sdss_filtered") + for f in folders.iterdir(): + if f.is_dir(): + cfg_args = copy.deepcopy(args) + cfg_args.extend(["--data.audio_path", str(f)]) + configs.append(cfg_args) + + for i, c in enumerate(configs): + sys.argv = c + print(c) + _ = LightningCLI() + + return + + +def optimize_sdss(): + """ + Optimize the synthesis parameters for the sdss dataset + """ + args = sys.argv + configs = [] + folders = Path("audio/sdss_filtered") + for f in folders.iterdir(): + if f.is_dir(): + cfg_args = copy.deepcopy(args) + cfg_args.extend(["--data.audio_path", str(f)]) + configs.append(cfg_args) + + outdir = Path("experiment/test_logs_direct_opt").joinpath("lightning_logs") + outdir.mkdir(exist_ok=True, parents=True) + + # Get the starting version number from the output directory + versions = sorted(outdir.glob("version_*")) + version = len(versions) + log.info(f"Starting version: {version}") + + # Run the 
optimization for each configuration + for i, c in enumerate(configs): + sys.argv = c + print(c) + + # Create the output directory + outdir_i = outdir.joinpath(f"version_{version + i}") + outdir_i.mkdir(exist_ok=True) + output = outdir_i.joinpath("metrics.csv") + + # Run the optimization + direct_optimization(output=output) + + return + + +def direct_optimization(output=None): + """ + Direct optimization of synthesis parameters + """ + cli = LightningCLI(run=False) + + # Set up the training dataset with no validation split + datamodule = cli.datamodule + datamodule.batch_size = 256 + datamodule.prepare_data() + datamodule.setup("test") + dataloader = datamodule.test_dataloader() + + for i, item in enumerate(dataloader): + _optimize_synth(cli.model, target=item[1]) + + metrics = {} + for i, (k, v) in enumerate(cli.model.feature_metrics.items()): + metrics[f"test/{k}"] = [ + v.compute().item(), + ] + + df = pd.DataFrame.from_dict(metrics, orient="columns") + if output is not None: + df.to_csv(output, index=False) + else: + df.to_csv("metrics.csv", index=False) + + return + + +def _optimize_synth(model: L.LightningModule, target: torch.Tensor): + """ + Optimize the synthesis parameters to match a feature target + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + target = target.to(device) + + synth = model.synth + preset = model.preset + modulation = torch.zeros( + target.shape[0], preset.shape[1], device=device, requires_grad=True + ) + + reference = model.feature(synth(preset)) + + # Create an optimizer + optimizer = torch.optim.Adam([modulation], lr=0.005) + schedule = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, patience=50, factor=0.5, verbose=True + ) + + pbar = tqdm(range(1000)) + for i in pbar: + optimizer.zero_grad() + + params = preset + modulation + params = torch.clip(params, 0.0, 1.0) + + y = synth(params) + y_features = model.feature(y) + + loss = model.loss_fn(y_features, reference, target) + loss.backward() + optimizer.step() + + schedule.step(loss) + + pbar.set_description(f"Loss: {loss.detach().item():.4f}") + + diff = y_features - reference + for i, (k, v) in enumerate(model.feature_metrics.items()): + v.update(diff[..., i], target[..., i]) diff --git a/timbreremap/data.py b/timbreremap/data.py new file mode 100644 index 0000000..ab45e65 --- /dev/null +++ b/timbreremap/data.py @@ -0,0 +1,313 @@ +""" +Datasets for training and testing. 
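+
+The main entry point is OnsetFeatureDataModule, which onset-aligns audio,
+computes short onset features alongside full-length features, and serves
+(onset_feature, feature_difference) pairs. A minimal usage sketch (the
+feature extractors and audio path here are placeholders):
+
+    dm = OnsetFeatureDataModule("audio/hits", feature, onset_feature)
+    dm.prepare_data()
+    dm.setup("fit")
+    loader = dm.train_dataloader()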
+""" +import logging +import os +from pathlib import Path +from typing import Optional +from typing import Union + +import lightning as L +import torch +import torchaudio +from torch.utils.data import DataLoader + +from timbreremap.np import OnsetFrames + +# Setup logging +logging.basicConfig() +log = logging.getLogger(__name__) +log.setLevel(level=os.environ.get("LOGLEVEL", "INFO")) + + +class OnsetFeatureDataset(torch.utils.data.Dataset): + """ + Dataset that returns pairs of onset features with full features + """ + + def __init__( + self, + onset_features: torch.Tensor, # Onset features + full_features: torch.Tensor, # Full features + weight: Optional[torch.Tensor] = None, # Feature weighting + onset_ref: Optional[torch.Tensor] = None, # Onset feature values for reference + ): + super().__init__() + self.onset_features = onset_features + self.full_features = full_features + assert self.onset_features.shape[0] == self.full_features.shape[0] + self.size = self.onset_features.shape[0] + self.weight = weight + self.onset_ref = onset_ref + + def __len__(self): + return self.size + + def __getitem__(self, idx): + onset_features = self.onset_features[idx] + if self.onset_ref is not None: + onset_features = self.onset_features[idx] - self.onset_ref + + if self.weight is None: + return onset_features, self.full_features[idx] + + return onset_features, self.full_features[idx], self.weight + + +class OnsetFeatureDataModule(L.LightningDataModule): + """ + A LightningDataModule for datasets with onset features + """ + + def __init__( + self, + audio_path: Union[Path, str], # Path to an audio file or directory + feature: torch.nn.Module, # A feature extractor + onset_feature: torch.nn.Module, # A feature extractor for short onsets + sample_rate: int = 48000, # Sample rate to compute features at + batch_size: int = 64, # Batch size + return_norm: bool = False, # Whether to return the feature norm + center_onset: bool = False, # Whether to return reference onset features as 0 + val_split: float = 0.0, # Fraction of data to use for validation + test_split: float = 0.0, # Fraction of data to use from testing + data_seed: int = 0, # Seed for random data splits + ): + super().__init__() + self.audio_path = Path(audio_path) + self.feature = feature + self.onset_feature = onset_feature + self.sample_rate = sample_rate + self.batch_size = batch_size + self.return_norm = return_norm + self.center_onset = center_onset + self.val_split = val_split + self.test_split = test_split + self.data_seed = data_seed + + def prepare_data(self) -> None: + """ + Load the audio file and prepare the dataset + """ + # Calculate the onset frames + onset_frames = OnsetFrames( + self.sample_rate, + frame_size=self.sample_rate, + on_thresh=10.0, + wait=10000, + backtrack=16, + overlap_buffer=1024, + ) + + # Load audio files from a directory + if self.audio_path.is_dir(): + audio_files = list(self.audio_path.glob("*.wav")) + assert len(audio_files) > 0, "No audio files found in directory" + log.info(f"Found {len(audio_files)} audio files.") + + audio = [] + for f in audio_files: + x, sr = torchaudio.load(f) + + # Resample if necessary + if sr != self.sample_rate: + x = torchaudio.transforms.Resample(sr, self.sample_rate)(x) + + # Onset detection and frame extraction to ensure all audio + # is the same length and is aligned at an onset + frames = onset_frames(x) + frames = torch.from_numpy(frames).float() + audio.append(frames) + + audio = torch.cat(audio, dim=0) + log.info(f"{len(audio)} samples after onset detection.") + + elif 
self.audio_path.is_file and self.audio_path.suffix == ".wav": + x, sr = torchaudio.load(self.audio_path) + + # Resample if necessary + if sr != self.sample_rate: + x = torchaudio.transforms.Resample(sr, self.sample_rate)(x) + + # Onset detection and frame extraction + audio = onset_frames(x) + audio = torch.from_numpy(audio).float() + + log.info(f"Found {len(audio)} samples.") + + else: + raise RuntimeError("Invalid audio path") + + # Cache audio + self.audio = audio + + # Compute full features + self.full_features = self.feature(audio) + loudsort = torch.argsort(self.full_features[:, 0], descending=True) + idx = int(len(loudsort) * 0.5) + idx = loudsort[idx] + + # Cache the index of the reference sample + self.ref_idx = idx + + # Compute the difference between the features of each audio and the centroid + self.diff = self.full_features - self.full_features[idx] + assert torch.allclose(self.diff[idx], torch.zeros_like(self.diff[idx])) + + # Create a per feature weighting + self.norm = torch.max(self.diff, dim=0)[0] - torch.min(self.diff, dim=0)[0] + self.norm = torch.abs(1.0 / self.norm).float() + + # Compute onset features for each sample + self.onset_features = self.onset_feature(audio) + + # Normalize onset features so each feature is in the range [0, 1] + self.onset_features = self.onset_features - self.onset_features.min(dim=0)[0] + self.onset_features = self.onset_features / self.onset_features.max(dim=0)[0] + assert torch.all(self.onset_features >= 0.0) + assert torch.all(self.onset_features <= 1.0) + + # Split the training data into train and test sets + if self.test_split > 0.0: + self.train_ids, self.test_ids = self.split_data( + loudsort, self.ref_idx, self.test_split + ) + else: + self.train_ids = loudsort + + # Split the remaing data into train and validation sets + if self.val_split > 0.0: + self.train_ids, self.val_ids = self.split_data( + self.train_ids, self.ref_idx, self.val_split + ) + + # Log the number of samples in each set + log.info(f"Training samples: {len(self.train_ids)}") + if hasattr(self, "val_ids"): + log.info(f"Validation samples: {len(self.val_ids)}") + if hasattr(self, "test_ids"): + log.info(f"Test samples: {len(self.test_ids)}") + + def split_data(self, ids: torch.Tensor, ref_idx: int, split: float): + """ + Select a subset of the data for validation + """ + assert split > 0.0 and split < 1.0 + + # Chunk the data into number of groups equal to the numbe of validation samples + # and then select a random sample from each chunk. + chunk_size = int(len(ids) * split) + 1 + chunks = torch.chunk(ids, chunk_size) + + train_ids = [] + val_ids = [] + + g = torch.Generator() + g.manual_seed(self.data_seed) + for chunk in chunks: + idx = torch.randint(0, len(chunk), (1,), generator=g).item() + # Ensure the validation sample is not the reference sample + if chunk[idx] == ref_idx: + idx = (idx + 1) % len(chunk) + + val_ids.append(chunk[idx].item()) + train_ids.extend(chunk[chunk != chunk[idx]].tolist()) + + assert len(train_ids) + len(val_ids) == len(ids) + assert len(set(train_ids).intersection(set(val_ids))) == 0 + return torch.tensor(train_ids), torch.tensor(val_ids) + + def setup(self, stage: str): + """ + Assign train/val/test datasets for use in dataloaders. 
+
+        Args:
+            stage: Current stage (fit, validate, test)
+        """
+        assert hasattr(self, "onset_features"), "Must call prepare_data() first"
+        assert hasattr(self, "full_features"), "Must call prepare_data() first"
+
+        onset_feature_ref = None
+        if self.center_onset:
+            onset_feature_ref = self.onset_features[self.ref_idx]
+
+        norm = self.norm if self.return_norm else None
+        if stage == "fit":
+            self.train_dataset = OnsetFeatureDataset(
+                self.onset_features[self.train_ids],
+                self.diff[self.train_ids],
+                norm,
+                onset_ref=onset_feature_ref,
+            )
+            if hasattr(self, "val_ids"):
+                self.val_dataset = OnsetFeatureDataset(
+                    self.onset_features[self.val_ids],
+                    self.diff[self.val_ids],
+                    norm,
+                    onset_ref=onset_feature_ref,
+                )
+        elif stage == "validate":
+            if hasattr(self, "val_ids"):
+                self.val_dataset = OnsetFeatureDataset(
+                    self.onset_features[self.val_ids],
+                    self.diff[self.val_ids],
+                    norm,
+                    onset_ref=onset_feature_ref,
+                )
+            else:
+                raise ValueError("No validation data available")
+        elif stage == "test":
+            if hasattr(self, "test_ids"):
+                self.test_dataset = OnsetFeatureDataset(
+                    self.onset_features[self.test_ids],
+                    self.diff[self.test_ids],
+                    norm,
+                    onset_ref=onset_feature_ref,
+                )
+            else:
+                self.train_dataset = OnsetFeatureDataset(
+                    self.onset_features,
+                    self.diff,
+                    norm,
+                    onset_ref=onset_feature_ref,
+                )
+        else:
+            raise NotImplementedError("Unknown stage")
+
+    def train_dataloader(self, shuffle=True):
+        batch_size = min(self.batch_size, len(self.train_dataset))
+        if batch_size < self.batch_size:
+            log.warning(
+                f"Reducing batch size to {batch_size}, "
+                "only that many samples available"
+            )
+
+        return DataLoader(
+            self.train_dataset,
+            batch_size=batch_size,
+            num_workers=0,
+            shuffle=shuffle,
+        )
+
+    def val_dataloader(self):
+        if not hasattr(self, "val_dataset"):
+            return None
+
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=0,
+            shuffle=False,
+        )
+
+    def test_dataloader(self):
+        if not hasattr(self, "test_dataset"):
+            log.info("No test dataset available, using full dataset for testing")
+            return self.train_dataloader(shuffle=False)
+
+        log.info("Testing on the test dataset")
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            num_workers=0,
+            shuffle=False,
+        )
diff --git a/timbreremap/export.py b/timbreremap/export.py
new file mode 100644
index 0000000..9420443
--- /dev/null
+++ b/timbreremap/export.py
@@ -0,0 +1,28 @@
+import torch
+
+
+class ParameterMapper(torch.nn.Module):
+    """
+    Wrapper for the parameter mapping model
+    to be used within the TorchDrum Plugin.
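+
+    A minimal sketch of the intended export path (mirroring
+    SaveTorchScriptCallback in callback.py; model, damp, and patch are a
+    trained mapping network and its damping/preset buffers):
+
+        mapper = ParameterMapper(model, damp, patch)
+        scripted = torch.jit.script(mapper)
+        torch.jit.save(scripted, "drum_mapper.pt")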
+ """ + + def __init__( + self, + model: torch.nn.Module, + damp: torch.Tensor, + patch: torch.Tensor, + param_project: torch.nn.Module = torch.nn.Identity(), + ): + super().__init__() + self.model = model + self.register_buffer("patch", patch) + self.register_buffer("damp", damp) + self.projection = param_project + + def forward(self, x: torch.Tensor) -> torch.Tensor: + param_mod = self.model(self.projection(x)) + params = param_mod * self.damp + params = torch.clip(params, -1.0, 1.0) + params = torch.cat([params, self.patch], dim=0) + return params diff --git a/timbreremap/feature.py b/timbreremap/feature.py new file mode 100644 index 0000000..cfce3a0 --- /dev/null +++ b/timbreremap/feature.py @@ -0,0 +1,519 @@ +""" +Differentiable Audio Features +""" +from collections import OrderedDict +from typing import List +from typing import Literal + +import numpy as np +import pyloudnorm +import torch +import torchaudio + + +class FeatureExtractor(torch.nn.Module): + """ + A serial connection of feature extraction layers. + """ + + def __init__(self, features: List[torch.nn.Module]): + super().__init__() + self.features = torch.nn.Sequential(*features) + + def forward(self, x: torch.Tensor): + return self.features(x) + + +class FeatureCollection(torch.nn.Module): + """ + A collection of feature extractors that are flattened + """ + + def __init__(self, features: List[torch.nn.Module]): + super().__init__() + self.features = torch.nn.ModuleList(features) + + def forward(self, x: torch.Tensor): + results = [] + for feature in self.features: + results.append(feature(x)) + return torch.cat(results, dim=-1) + + +class NumpyWrapper(torch.nn.Module): + """ + Wrap a feature extractor for numpy inputs. + """ + + def __init__(self, func): + super().__init__() + self.func = func + + def forward(self, x: np.ndarray): + return self.func(torch.from_numpy(x).float()).numpy() + + +class OnsetSegment(torch.nn.Module): + """ + Segment a signal in relation to the start of the signal. + """ + + def __init__(self, window: int = 2048, delay: int = 0): + super().__init__() + self.window = window + self.delay = delay + + def forward(self, x: torch.Tensor): + return x[..., self.delay : self.delay + self.window] + + +class CascadingFrameExtactor(torch.nn.Module): + """ + Given frames. 
Computes features and summary statistics
+    over an increasing number of frames from the onset.
+    """
+
+    def __init__(
+        self,
+        extractors: List[torch.nn.Module],
+        num_frames: List[int],
+        frame_size: int = 2048,
+        hop_size: int = 1024,
+        pad_start: int = None,
+        include_mean: bool = True,
+        include_diff: bool = False,
+        always_from_onset: bool = False,
+    ):
+        super().__init__()
+        self.extractors = extractors
+        self.num_frames = num_frames
+        self.frame_size = frame_size
+        self.hop_size = hop_size
+        self.pad_start = pad_start
+        self.include_mean = include_mean
+        self.include_diff = include_diff
+        self.always_from_onset = always_from_onset
+        self.flattened_features = []
+
+        frame_feature_names = []
+        k = 0
+        for n in num_frames:
+            if include_mean:
+                frame_feature_names.append(f"{k}_{n}_mean")
+            if include_diff and n > 1:
+                frame_feature_names.append(f"{k}_{n}_diff")
+            k = k if always_from_onset else k + n
+
+        for extractor in self.extractors:
+            ename = extractor._get_name()
+            for n in frame_feature_names:
+                self.flattened_features.append((ename, n))
+
+        self.num_features = len(self.flattened_features)
+
+    def forward(
+        self,
+        x: torch.Tensor,  # (batch, samples)
+    ):
+        y = self.get_as_dict(x)
+        flattened_y = []
+        for extractor, feature in self.flattened_features:
+            flattened_y.append(y[extractor][feature].unsqueeze(-1))
+
+        return torch.cat(flattened_y, dim=-1)
+
+    def get_as_dict(self, x: torch.Tensor):
+        """
+        Returns a dictionary of features
+        """
+        if self.pad_start is not None:
+            x = torch.nn.functional.pad(x, (self.pad_start, 0))
+
+        x = x.unfold(-1, self.frame_size, self.hop_size)
+        assert x.ndim == 3
+        assert x.shape[1] >= sum(self.num_frames)
+
+        results = OrderedDict()
+        for extractor in self.extractors:
+            ename = extractor._get_name()
+            if ename not in results:
+                results[ename] = OrderedDict()
+
+            k = 0
+            for n in self.num_frames:
+                frames = x[..., k : k + n, :]
+                y = extractor(frames)
+
+                if self.include_mean:
+                    y_mean = y.mean(dim=-1)
+                    results[ename][f"{k}_{n}_mean"] = y_mean
+
+                if self.include_diff and n > 1:
+                    y_diff = torch.diff(y, dim=-1)
+                    y_diff_mean = y_diff.mean(dim=-1)
+                    results[ename][f"{k}_{n}_diff"] = y_diff_mean
+
+                k = k if self.always_from_onset else k + n
+
+        return results
+
+
+class RMS(torch.nn.Module):
+    """
+    Root mean square of a signal.
+    """
+
+    def __init__(
+        self,
+        db: bool = False,  # Convert to dB
+    ):
+        super().__init__()
+        self.db = db
+
+    def forward(self, x: torch.Tensor):
+        rms = torch.sqrt(torch.mean(torch.square(x), dim=-1) + 1e-8)
+        if self.db:
+            rms = 20 * torch.log10(rms + 1e-8)
+        # Clipping isn't gradient-friendly, so the floor is left disabled:
+        # rms = torch.clamp(rms, min=-120.0)
+        return rms
+
+
+class Loudness(torch.nn.Module):
+    """
+    Computes loudness (LKFS) by applying K-weighting filters based on ITU-R BS.1770-4
+    """
+
+    def __init__(
+        self,
+        sample_rate: int,
+        epsilon: float = 1e-8,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.epsilon = epsilon
+
+        # Setup K-weighting filters
+        a_coefs = []
+        b_coefs = []
+        filters = pyloudnorm.Meter(sample_rate, "K-weighting")._filters
+        for filt in filters.items():
+            a_coefs.append(filt[1].a)
+            b_coefs.append(filt[1].b)
+
+        a_coefs = np.array(a_coefs)
+        b_coefs = np.array(b_coefs)
+        self.register_buffer("a_coefs", torch.tensor(a_coefs, dtype=torch.float))
+        self.register_buffer("b_coefs", torch.tensor(b_coefs, dtype=torch.float))
+
+    def prefilter(self, x: torch.Tensor):
+        """
+        Prefilter the signal with K-weighting filters.
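+
+        The coefficients are the two K-weighting biquads taken from the
+        pyloudnorm meter above and applied in series via
+        torchaudio.functional.lfilter.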
+ """ + # Apply K-weighting filters in series + a_coefs = self.a_coefs.to(x.device).split(1, dim=0) + b_coefs = self.b_coefs.to(x.device).split(1, dim=0) + for a, b in zip(a_coefs, b_coefs): + x = torchaudio.functional.lfilter(x, a.squeeze(), b.squeeze()) + + return x + + def forward(self, x: torch.Tensor): + """ + Compute loudness (LKFS) of a signal. + """ + x = self.prefilter(x) + loudness = torch.mean(torch.square(x), dim=-1) + loudness = -0.691 + 10.0 * torch.log10(loudness + self.epsilon) + return loudness + + +class SpectralCentroid(torch.nn.Module): + """ + Spectral centroid of a signal. + """ + + def __init__( + self, + sample_rate: int, + window: Literal["hann", "flat_top", "none"] = "hann", + compress: bool = False, + floor: float = None, # Floor spectral magnitudes to this value + scaling: Literal["semitone", "kazazis", "none"] = "semitone", + ): + super().__init__() + self.sample_rate = sample_rate + self.scaling = scaling + self.window_fn = get_window_fn(window) + self.compress = compress + self.floor = floor + + def forward(self, x: torch.Tensor): + # Apply a window + if self.window_fn is not None: + window = self.window_fn(x.shape[-1], device=x.device) + x = x * window + + # Calculate FFT + X = torch.fft.rfft(x, dim=-1) + X = torch.abs(X) + + if self.floor is not None: + X = torch.where(X < self.floor, self.floor, X) + + # Compression + if self.compress: + X = torch.log(1 + X) + + X_norm = torch.nn.functional.normalize(X, p=1, dim=-1) + + # Calculate spectral centroid + bins = torch.arange(X.shape[-1], device=x.device) + spectral_centroid = torch.sum(bins * X_norm, dim=-1) + + # Convert to Hz + bin_hz = self.sample_rate / x.shape[-1] + spectral_centroid = spectral_centroid * bin_hz + + # Convert to semitones + if self.scaling == "semitone": + spectral_centroid = ( + 12 * torch.log2((spectral_centroid + 1e-8) / 440.0) + 69.0 + ) + elif self.scaling == "kazazis": + spectral_centroid = -34.61 * torch.pow(spectral_centroid, -0.1621) + 21.2985 + + return spectral_centroid + + +class SpectralSpread(torch.nn.Module): + """ + Spectral spread of a signal. + + TODO: there is a lot of repeated code here with SpectralCentroid, and in general + with the spectral features. This should be refactored to avoid repeated calls to + FFTs etc. + """ + + def __init__( + self, + window: Literal["hann", "flat_top", "none"] = "hann", + compress: bool = False, + floor: float = None, # Floor spectral magnitudes to this value + ): + super().__init__() + self.window_fn = get_window_fn(window) + self.compress = compress + self.floor = floor + + def forward(self, x: torch.Tensor): + # Apply a window + if self.window_fn is not None: + window = self.window_fn(x.shape[-1], device=x.device) + x = x * window + + # Calculate FFT + X = torch.fft.rfft(x, dim=-1) + X = torch.abs(X) + + if self.floor is not None: + X = torch.where(X < self.floor, self.floor, X) + + # Compression + if self.compress: + X = torch.log(1 + X) + + X_norm = torch.nn.functional.normalize(X, p=1, dim=-1) + + # Calculate spectral centroid + bins = torch.arange(X.shape[-1], device=x.device) + spectral_centroid = torch.sum(bins * X_norm, dim=-1) + + # Calculate spectral spread + spectral_spread = torch.sum( + torch.square(bins - spectral_centroid[..., None]) * X_norm, dim=-1 + ) + + return spectral_spread + + +class SpectralFlatness(torch.nn.Module): + """ + Spectral flatness of a signal. 
+ """ + + def __init__( + self, + amin: float = 1e-10, + window: Literal["hann", "flat_top", "none"] = "hann", + compress: bool = False, + ) -> None: + super().__init__() + self.amin = amin + self.window_fn = get_window_fn(window) + self.compress = compress + + def forward(self, x: torch.Tensor): + # Apply a window + if self.window_fn is not None: + window = self.window_fn(x.shape[-1], device=x.device) + x = x * window + + # Calculate FFT + X = torch.fft.rfft(x, dim=-1) + X = torch.abs(X) + + # Compression + if self.compress: + X = torch.log(1 + X) + + X_power = torch.where(X**2.0 < self.amin, self.amin, X**2.0) + gmean = torch.exp(torch.mean(torch.log(X_power), dim=-1)) + amean = torch.mean(X_power, dim=-1) + + # Calculate spectral flatness + spectral_flatness = gmean / amean + + # Convert to dB + spectral_flatness = 20.0 * torch.log10(spectral_flatness + 1e-8) + + return spectral_flatness + + +class SpectralFlux(torch.nn.Module): + """ + Spectral flux of a signal. + """ + + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + assert x.ndim == 3 + assert x.shape[1] > 1, "Must have at least two frames" + + # Apply a window + window = torch.hann_window(x.shape[-1], device=x.device) + x = x * window + + # Calculate FFT + X = torch.fft.rfft(x, dim=-1) + X = torch.abs(X) + + flux = torch.diff(X, dim=-2) + flux = (flux + torch.abs(flux)) / 2 + flux = torch.square(flux) + flux = torch.sum(flux, dim=-1) + + return flux + + +class AmplitudeEnvelope(torch.nn.Module): + """ + Get the amplitude envelope of a signal by convolution with a window. + """ + + def __init__(self, window: int = 2048): + super().__init__() + self.window = window + + def forward(self, x: torch.Tensor): + assert x.ndim == 3 + + # Calculate the amplitude envelope + window = torch.hann_window(self.window, device=x.device) + window = window[None, None, :] + + x = torch.square(x) + y = torch.nn.functional.conv1d(x, window, padding="same") + assert torch.all(torch.isfinite(y)) + + return y + + +class TemporalCentroid(torch.nn.Module): + """ + Temporal centroid of a signal. + """ + + def __init__( + self, + sample_rate: int, + window_size: int = 2048, + scaling: Literal["schlauch", "none"] = "none", + ): + super().__init__() + self.sample_rate = sample_rate + self.window_size = window_size + self.envelope = AmplitudeEnvelope(window=window_size) + self.scaling = scaling + + def forward(self, x: torch.Tensor): + env = self.envelope(x) + y = torch.sum(env * torch.arange(env.shape[-1], device=x.device), dim=-1) + y = y / (torch.sum(env, dim=-1) + 1e-8) + y = y / self.sample_rate * 1000.0 + if self.scaling == "schlauch": + y = 0.03 * torch.pow(y, 1.864) + return y + + +class NormSum(torch.nn.Module): + """ + Sum of the normalized signal. + """ + + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + assert x.ndim == 2 + assert not torch.all(torch.isnan(x)) + + norm = x / torch.max(torch.abs(x)) + norm = torch.nan_to_num(norm, nan=0.0) + y = torch.sum(norm, dim=-1) + y = y / float(x.shape[-1]) + return y + + +class AmpEnvSum(FeatureExtractor): + """ + Sum of the amplitude envelope. 
+ """ + + def __init__(self, window: int = 2048): + super().__init__([AmplitudeEnvelope(window), NormSum()]) + + +def get_window_fn(window: str): + if window == "hann": + return torch.hann_window + elif window == "flat_top": + return flat_top_window + elif window == "none": + return None + else: + raise ValueError(f"Unknown window type: {window}") + + +def flat_top_window(size, device="cpu"): + """ + Flat top window for spectral analysis. + https://en.wikipedia.org/wiki/Window_function#Flat_top_window + """ + a0 = 0.21557895 + a1 = 0.41663158 + a2 = 0.277263158 + a3 = 0.083578947 + a4 = 0.006947368 + + n = torch.arange(size, dtype=torch.float, device=device) + window = ( + a0 + - a1 * torch.cos(2 * torch.pi * n / (size - 1)) + + a2 * torch.cos(4 * torch.pi * n / (size - 1)) + - a3 * torch.cos(6 * torch.pi * n / (size - 1)) + + a4 * torch.cos(8 * torch.pi * n / (size - 1)) + ) + return window diff --git a/timbreremap/loss.py b/timbreremap/loss.py new file mode 100644 index 0000000..4c36e76 --- /dev/null +++ b/timbreremap/loss.py @@ -0,0 +1,27 @@ +""" +Differentiable Loss Functions +""" +import torch + + +class FeatureDifferenceLoss(torch.nn.Module): + """ + Loss function that calculates the error between the difference of two features + and a target difference. + """ + + def __init__(self, loss: callable = torch.nn.L1Loss()): + super().__init__() + self.loss = loss + + def forward( + self, + y_pred: torch.tensor, + y_true: torch.tensor, + target_diff: torch.tensor, + weight: float = 1.0, + ): + diff = y_pred - y_true + error = torch.abs(diff - target_diff) + error = error * weight + return torch.mean(error) diff --git a/timbreremap/model.py b/timbreremap/model.py new file mode 100644 index 0000000..868a793 --- /dev/null +++ b/timbreremap/model.py @@ -0,0 +1,102 @@ +""" +Neural Network Models and related functions +""" +import math + +import torch + + +def scale_function( + x: torch.Tensor, + exponent: float = 4.0, + max_value: float = 2.0, + threshold: float = -1.0, +): + """ + Scales a parameter to a range of [threshold, max_value] with a slope of exponent. + A threshold is used to stabilize the gradient near zero. 
+ """ + return max_value * torch.sigmoid(x) ** math.log(exponent) + threshold + + +class MLP(torch.nn.Module): + """ + Configurable multilayer perceptron + """ + + def __init__( + self, + in_size: int, # Input parameter size + hidden_size: int, # Hidden layer size + out_size: int, # Output parameter size + num_layers: int, # Number of hidden layers + activation: torch.nn.Module = torch.nn.Sigmoid(), # Activation function + scale_output: bool = False, # Scale output to [-1, 1] + input_bias: float = 0.0, # Bias for the input layer + layer_norm: bool = False, # Use layer normalization + normalize_input: bool = False, # Normalize input + init_std: float = 1e-3, # Standard deviation of initial weights + ): + super().__init__() + channels = [in_size] + (num_layers) * [hidden_size] + net = [] + for i in range(num_layers): + net.append(torch.nn.Linear(channels[i], channels[i + 1])) + if layer_norm: + net.append( + torch.nn.LayerNorm(channels[i + 1], elementwise_affine=False) + ) + net.append(activation) + + net.append(torch.nn.Linear(channels[-1], out_size)) + self.in_size = in_size + self.net = torch.nn.Sequential(*net) + self.scale_output = scale_output + self.input_bias = input_bias + self.normalize_input = normalize_input + self.init_std = init_std + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, torch.nn.Linear): + with torch.no_grad(): + if isinstance(m, torch.nn.Linear) and m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + torch.nn.init.normal_(m.weight, 0, self.init_std) + + def forward(self, x: torch.Tensor): + x = x + self.input_bias + x = self.net(x) + if self.scale_output: + x = torch.tanh(x) + return x + + +class LinearMapping(torch.nn.Module): + """ + Single layer of linear mappings from input to output. + """ + + def __init__( + self, + in_size: int, # Input parameter size + out_size: int, # Output parameter size + input_bias: float = 0.0, # Bias for the input layer + bias: bool = False, # Optional bias parameters + clamp: bool = False, # Clamp output range + init_std: float = 1e-3, # Standard deviation of initial weights + ): + super().__init__() + self.in_size = in_size + self.clamp = clamp + self.input_bias = input_bias + self.net = torch.nn.Linear(in_size, out_size, bias=bias) + torch.nn.init.normal_(self.net.weight, 0.0, init_std) + if bias: + torch.nn.init.constant_(self.net.bias, 0.0) + + def forward(self, x: torch.Tensor): + y = self.net(x + self.input_bias) + if self.clamp: + y = torch.clamp(y, -1.0, 1.0) + return y diff --git a/timbreremap/np/__init__.py b/timbreremap/np/__init__.py new file mode 100644 index 0000000..1fe2e4c --- /dev/null +++ b/timbreremap/np/__init__.py @@ -0,0 +1,6 @@ +from .core import EnvelopeFollower +from .core import HighPassFilter +from .core import OnsetDetection +from .core import OnsetFrames +from .features import Loudness +from .features import SpectralCentroid diff --git a/timbreremap/np/core.py b/timbreremap/np/core.py new file mode 100644 index 0000000..099a84f --- /dev/null +++ b/timbreremap/np/core.py @@ -0,0 +1,174 @@ +""" +Core funcionality for the numpy backend. 
+""" +import numpy as np +from numba import jit +from scipy import signal + + +class HighPassFilter: + """ + Simple implementation of a high-pass filter + """ + + def __init__( + self, sr: int, cutoff: float, q: float = 0.707, peak_gain: float = 0.0 + ): + self.sr = sr + self.cutoff = cutoff + self.q = q + self.peak_gain = peak_gain + self._init_filter() + + def _init_filter(self): + K = np.tan(np.pi * self.cutoff / self.sr) + norm = 1 / (1 + K / self.q + K * K) + self.a0 = 1 * norm + self.a1 = -2 * self.a0 + self.a2 = self.a0 + self.b1 = 2 * (K * K - 1) * norm + self.b2 = (1 - K / self.q + K * K) * norm + + def __call__(self, x: np.array): + assert x.ndim == 2 and x.shape[0] == 1 + y = signal.lfilter( + [self.a0, self.a1, self.a2], [1, self.b1, self.b2], x, axis=1, zi=None + ) + return y + + +@jit(nopython=True) +def envelope_follower(x: np.array, up: float, down: float, initial: float = 0.0): + y = np.zeros_like(x) + y0 = initial + for i in range(y.shape[-1]): + if x[0, i] > y0: + y0 = up * (x[0, i] - y0) + y0 + else: + y0 = down * (x[0, i] - y0) + y0 + y[0, i] = y0 + return y + + +class EnvelopeFollower: + def __init__(self, attack_samples: int, release_samples: int): + self.up = 1.0 / attack_samples + self.down = 1.0 / release_samples + + def __call__(self, x: np.array, initial: float = 0.0): + assert x.ndim == 2 and x.shape[0] == 1 + return envelope_follower(x, self.up, self.down, initial=initial) + + +@jit(nopython=True) +def detect_onset(x: np.array, on_thresh: float, off_thresh: float, wait: int): + debounce = -1 + onsets = [] + for i in range(1, x.shape[-1]): + if x[0, i] >= on_thresh and x[0, i - 1] < on_thresh and debounce == -1: + onsets.append(i) + debounce = wait + + if debounce > 0: + debounce -= 1 + + if debounce == 0 and x[0, i] < off_thresh: + debounce = -1 + + return onsets + + +class OnsetDetection: + def __init__( + self, + sr: int, + on_thresh: float = 16.0, + off_thresh: float = 4.6666, + wait: int = 1323, + min_db: float = -55.0, + eps: float = 1e-8, + ): + self.env_fast = EnvelopeFollower(3.0, 383.0) + self.env_slow = EnvelopeFollower(2205.0, 2205.0) + self.high_pass = HighPassFilter(sr, 600.0) + self.on_thresh = on_thresh + self.off_thresh = off_thresh + self.min_db = min_db + self.wait = wait + self.eps = eps + + def _onset_signal(self, x: np.array): + # Filter + x = self.high_pass(x) + + # Rectify, convert to dB, and set minimum value + x = np.abs(x) + x = 20 * np.log10(x + self.eps) + x[x < self.min_db] = self.min_db + + # Calculate envelope + env_fast = self.env_fast(x, initial=self.min_db) + env_slow = self.env_slow(x, initial=self.min_db) + diff = env_fast - env_slow + + return diff + + def __call__(self, x: np.array): + assert x.ndim == 2 and x.shape[0] == 1, "Monophone audio only." + + # Calculate envelope + onset = self._onset_signal(x) + onsets = detect_onset(onset, self.on_thresh, self.off_thresh, self.wait) + + return onsets + + +class OnsetFrames: + def __init__( + self, + sr: int, + frame_size: int, + pad_overlap: bool = True, # Prevent overlap between frames with padding + overlap_buffer: int = 32, # Number of samples to look ahead for overlap + backtrack: int = 0, # Number of samples to backtrack for extraction + **kwargs + ): + self.sr = sr + self.frame_size = frame_size + self.pad_overlap = pad_overlap + self.overlap_buffer = overlap_buffer + self.backtrack = backtrack + self.onset = OnsetDetection(sr, **kwargs) + + def __call__(self, x: np.array): + assert x.ndim == 2 and x.shape[0] == 1, "Monophone audio only." 
+
+        # Compute onsets
+        onsets = self.onset(x)
+
+        # Extract frames
+        frames = []
+        for j, onset in enumerate(onsets):
+            # Extract the frame -- avoid overlap if pad_overlap is True
+            start = max(onset - self.backtrack, 0)
+            if (
+                self.pad_overlap
+                and j < len(onsets) - 1
+                and (onsets[j + 1] - start < self.frame_size)
+            ):
+                frame = x[0, start : onsets[j + 1] - self.overlap_buffer]
+
+                # Apply a fade out to the end of the frame
+                fade = np.hanning(self.overlap_buffer * 2)[self.overlap_buffer :]
+                frame[-self.overlap_buffer :] *= fade
+            else:
+                frame = x[0, start : start + self.frame_size]
+
+            # Pad with zeros if necessary
+            if frame.shape[-1] < self.frame_size:
+                frame = np.pad(frame, (0, self.frame_size - frame.shape[-1]))
+
+            assert frame.shape[-1] == self.frame_size
+            frames.append(frame)
+
+        return np.array(frames)
diff --git a/timbreremap/np/features.py b/timbreremap/np/features.py
new file mode 100644
index 0000000..2069622
--- /dev/null
+++ b/timbreremap/np/features.py
@@ -0,0 +1,45 @@
+"""
+Audio feature extraction functions for numpy arrays.
+"""
+import numpy as np
+import torch
+
+
+class Loudness:
+    def __init__(self, db: bool = False, eps: float = 1e-8):
+        super().__init__()
+        self.db = db
+        self.eps = eps
+
+    def __call__(self, frames: np.array):
+        assert frames.ndim == 2
+
+        # Calculate RMS
+        rms = np.sqrt(np.mean(np.square(frames), axis=1))
+
+        # Convert to dB
+        if self.db:
+            rms = 20 * np.log10(rms + self.eps)
+
+        return rms
+
+
+class SpectralCentroid:
+    def __init__(self):
+        pass
+
+    def __call__(self, frames: np.array):
+        assert frames.ndim == 2
+
+        # Calculate FFT
+        X = np.fft.rfft(frames, axis=1)
+        X = np.abs(X)
+
+        # Normalize -- using the torch version for compatibility
+        X_norm = torch.nn.functional.normalize(torch.from_numpy(X), p=1, dim=-1).numpy()
+
+        # Calculate spectral centroid
+        bins = np.arange(X.shape[1])
+        spectral_centroid = np.sum(bins * X_norm, axis=1)
+
+        return spectral_centroid
diff --git a/timbreremap/optuna.py b/timbreremap/optuna.py
new file mode 100644
index 0000000..3c5df47
--- /dev/null
+++ b/timbreremap/optuna.py
@@ -0,0 +1,134 @@
+"""
+Hyperparameter optimization using Optuna CLI
+"""
+import copy
+import logging
+import os
+import sys
+from functools import partial
+
+import optuna
+from lightning.pytorch.cli import LightningCLI
+from optuna.integration import PyTorchLightningPruningCallback
+
+# Setup logging
+logging.basicConfig()
+log = logging.getLogger(__name__)
+log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))
+
+
+def objective(trial, args=None):
+    n_layers = trial.suggest_int("num_layers", 1, 3)
+    hidden_size = trial.suggest_categorical(
+        "hidden_size", [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
+    )
+    init_std = trial.suggest_float("init_std", 1e-7, 1.0, log=True)
+    activation = trial.suggest_categorical(
+        "activation",
+        [
+            "torch.nn.LeakyReLU",
+            "torch.nn.ReLU",
+            "torch.nn.Sigmoid",
+            "torch.nn.Tanh",
+        ],
+    )
+    learning_rate = trial.suggest_float("learning_rate", 1e-7, 0.01, log=True)
+    scale_output = trial.suggest_categorical("scale_output", [True, False])
+
+    cfg_args = copy.deepcopy(args)
+    cfg_args.extend(["--trainer.max_epochs", "200"])
+
+    # Hyperparameters -- note: the MLP init parameter is named init_std
+    cfg_args.extend(["--model.model.num_layers", str(n_layers)])
+    cfg_args.extend(["--model.model.hidden_size", str(hidden_size)])
+    cfg_args.extend(["--model.model.init_std", str(init_std)])
+    cfg_args.extend(["--model.model.activation", str(activation)])
+    cfg_args.extend(["--model.model.scale_output", str(scale_output)])
+
cfg_args.extend(["--optimizer.lr", str(learning_rate)]) + + sys.argv = cfg_args + + cli = LightningCLI(run=False) + cli.trainer.callbacks.append( + PyTorchLightningPruningCallback(trial, monitor="val/loss") + ) + + try: + cli.trainer.fit(cli.model, cli.datamodule) + except optuna.TrialPruned as e: + log.info(e) + return cli.trainer.callback_metrics["val/loss"].item() + except Exception as e: + log.error(e) + return None + + assert cli.trainer.callback_metrics["val/loss"].item() is not None + return cli.trainer.callback_metrics["val/loss"].item() + + +def objective_linear(trial, args=None): + init_var = trial.suggest_float("init_std", 1e-7, 1.0, log=True) + learning_rate = trial.suggest_float("learning_rate", 1e-7, 0.01, log=True) + + cfg_args = copy.deepcopy(args) + cfg_args.extend(["--trainer.max_epochs", "200"]) + + # Hyperparameters + cfg_args.extend(["--model.model.init_std", str(init_var)]) + cfg_args.extend(["--optimizer.lr", str(learning_rate)]) + + sys.argv = cfg_args + + cli = LightningCLI(run=False) + cli.trainer.callbacks.append( + PyTorchLightningPruningCallback(trial, monitor="val/loss") + ) + + try: + cli.trainer.fit(cli.model, cli.datamodule) + except optuna.TrialPruned as e: + log.info(e) + return cli.trainer.callback_metrics["val/loss"].item() + except Exception as e: + log.error(e) + return None + + assert cli.trainer.callback_metrics["val/loss"].item() is not None + return cli.trainer.callback_metrics["val/loss"].item() + + +def run_optuna(): + """ """ + args = sys.argv + if "--linear" in args: + args.remove("--linear") + args.extend(["--model.model", "cfg/models/linear_mapper.yaml"]) + objective_func = partial(objective_linear, args=args) + name = "Linear Mapper" + else: + objective_func = partial(objective, args=args) + name = "MLP Mapper" + + pruner = optuna.pruners.MedianPruner() + study = optuna.create_study( + direction="minimize", + pruner=pruner, + storage="sqlite:///db.sqlite3", + study_name=f"808 Snare HyperParam Jan29 - {name}", + load_if_exists=True, + ) + study.optimize(objective_func, n_trials=200) + + print("Number of finished trials: {}".format(len(study.trials))) + + print("Best trial:") + trial = study.best_trial + + print(" Value: {}".format(trial.value)) + + print(" Params: ") + for key, value in trial.params.items(): + print(" {}: {}".format(key, value)) + + return diff --git a/timbreremap/synth.py b/timbreremap/synth.py new file mode 100644 index 0000000..1af9490 --- /dev/null +++ b/timbreremap/synth.py @@ -0,0 +1,654 @@ +""" +Modules for synthesizing drum sounds + +TODO: should we use torchsynth for this? 
+""" +import json +from collections import OrderedDict +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +import torchaudio +from einops import repeat + + +class ParamaterNormalizer: + """ + Holds min and max values for a parameter and provides methods for normalizing + between 0 and 1 and vice versa + """ + + def __init__(self, min_value, max_value, description=None): + self.min_value = min_value + self.max_value = max_value + self.description = description + + def __repr__(self): + return ( + f"ParamaterNormalizer(Min: {self.min_value}, Max: {self.max_value}, " + f"Desc: {self.description})" + ) + + def from_0to1(self, x): + return self.min_value + (self.max_value - self.min_value) * x + + def to_0to1(self, x): + return (x - self.min_value) / (self.max_value - self.min_value) + + +class AbstractModule(torch.nn.Module): + def __init__(self, sample_rate: int): + super().__init__() + self.sample_rate = sample_rate + self.normalizers = OrderedDict() + + def forward(self, *args, **kwargs): + raise NotImplementedError + + +class ExpDecayEnvelope(AbstractModule): + """ + Exponential decay envelope + C++ version: ExpDecayEnvelope + """ + + def __init__( + self, sample_rate: int, decay_min: float = 10.0, decay_max: float = 2000.0 + ): + super().__init__(sample_rate=sample_rate) + self.sample_rate = sample_rate + self.normalizers["decay"] = ParamaterNormalizer( + decay_min, decay_max, "decay time ms" + ) + self.attack_samples = int(0.001 * self.sample_rate) + self.attack_incr = 1.0 / self.attack_samples + + def forward(self, num_samples: int, decay: torch.Tensor): + assert decay.ndim == 2 + assert decay.shape[1] == 1 + assert ( + decay.min() >= 0.0 and decay.max() <= 1.0 + ), "param must be between 0 and 1" + + # Calculated the samplewise decay rate + decay_ms = self.normalizers["decay"].from_0to1(decay) + decay_samples = decay_ms * self.sample_rate / 1000.0 + decay_rate = 1.0 - (6.91 / decay_samples) + + # Calculate the decay envelope + decay_samples = num_samples - self.attack_samples + assert decay_samples > 0, "num_samples must be greater than attack_samples" + + env = torch.ones(decay_rate.shape[0], decay_samples, device=decay_rate.device) + env[:, 1:] = decay_rate + env = torch.cumprod(env, dim=-1) + + # Add attack + attack = torch.ones( + decay_rate.shape[0], self.attack_samples, device=decay_rate.device + ) + attack = torch.cumsum(attack * self.attack_incr, dim=-1) + + # Combine attack and decay + env = torch.cat((attack, env), dim=-1) + return env + + +class ExponentialDecay(AbstractModule): + """ + Exponential decay envelope + """ + + def __init__(self, sample_rate: int): + super().__init__(sample_rate=sample_rate) + self.sample_rate = sample_rate + self.normalizers["decay"] = ParamaterNormalizer(10.0, 2000.0, "decay time ms") + + def forward(self, num_samples: int, decay: torch.Tensor): + assert decay.ndim == 2 + assert decay.shape[1] == 1 + + # Calculated the samplewise decay rate + decay_ms = self.normalizers["decay"].from_0to1(decay) + decay_samples = decay_ms * self.sample_rate / 1000.0 + decay_rate = 1.0 - (6.91 / decay_samples) + + # Calculate the envelope + env = torch.ones(decay_rate.shape[0], num_samples, device=decay_rate.device) + env[:, 1:] = decay_rate + env = torch.cumprod(env, dim=-1) + + return env + + +class SinusoidalOscillator(AbstractModule): + """ + A sinusoidal oscillator + C++ version: SinusoidalOscillator + + TODO: slight numerical differences between cpp and py versions when using modulation 
+ """ + + def __init__(self, sample_rate: int): + super().__init__(sample_rate=sample_rate) + self.normalizers["freq"] = ParamaterNormalizer(20.0, 2000.0, "frequency (Hz)") + self.normalizers["mod"] = ParamaterNormalizer( + -1.0, 2.0, "freq envelope amount (ratio)" + ) + + def forward( + self, + num_samples: int, + freq: torch.Tensor, + mod_env: torch.Tensor, + mod_amount: torch.Tensor, + ): + assert freq.ndim == 2 + assert freq.shape[1] == 1 + assert mod_amount.shape == freq.shape + assert mod_env.ndim == 2 + assert mod_env.shape[1] == num_samples + + # Calculate the phase + f0 = self.normalizers["freq"].from_0to1(freq) + f0 = 2 * torch.pi * f0 / self.sample_rate + + freq_env = torch.ones(f0.shape[0], num_samples, device=f0.device) * f0 + + mod_amount = self.normalizers["mod"].from_0to1(mod_amount) + mod_env = mod_env * mod_amount * f0 + freq_env = freq_env + mod_env + + # Add a zero to the beginning of the envelope for zero intial phase + freq_env = torch.cat((torch.zeros_like(freq_env[:, :1]), freq_env), dim=-1) + + # Integrate to get the phhase + phase = torch.cumsum(freq_env, dim=-1)[:, :-1] + + # Generate the signal + y = torch.sin(phase) + + return y + + +class Tanh(AbstractModule): + """ + tanh waveshaper + C++ version: Tanh + """ + + def __init__(self, sample_rate: int): + super().__init__(sample_rate) + self.normalizers["in_gain"] = ParamaterNormalizer( + -24.0, 24.0, "input gain (db)" + ) + + def forward(self, x: torch.Tensor, in_gain: torch.Tensor): + in_gain = self.normalizers["in_gain"].from_0to1(in_gain) + in_gain = torch.pow(10.0, in_gain / 20.0) + return torch.tanh(in_gain * x) + + +class Gain(AbstractModule): + """ + Gain module with a gain parameter in decibels + C++ version: Gain + """ + + def __init__( + self, sample_rate: int, min_gain: float = -60.0, max_gain: float = 6.0 + ): + super().__init__(sample_rate) + self.normalizers["gain"] = ParamaterNormalizer(min_gain, max_gain, "gain (db)") + + def forward(self, x: torch.Tensor, gain: torch.Tensor): + gain = self.normalizers["gain"].from_0to1(gain) + gain = torch.pow(10.0, gain / 20.0) + return gain * x + + +class WhiteNoise(AbstractModule): + """ + White noise generator + C++ version: WhiteNoise + """ + + def __init__( + self, + sample_rate: int, + buffer_noise: bool = False, + buffer_size: int = 0, + device: torch.device = "cpu", + ): + super().__init__(sample_rate) + if buffer_noise: + assert buffer_size > 0, "buffer_size must be greater than 0" + noise = torch.rand(1, buffer_size, device=device) * 2.0 - 1.0 + self.register_buffer("noise", noise) + + def forward(self, batch_size: int, num_samples: int, device: torch.device): + if hasattr(self, "noise"): + y = repeat(self.noise, "1 n -> b n", b=batch_size) + y = y[:, :num_samples] + else: + y = torch.rand(batch_size, num_samples, device=device) * 2.0 - 1.0 + return y + + +class CrossFade(AbstractModule): + """ + Cross fade between two signals + C++ version: CrossFade + """ + + def __init__(self, sample_rate: int): + super().__init__(sample_rate) + self.normalizers["fade"] = ParamaterNormalizer(0.0, 1.0, "fade amount") + + def forward(self, x1: torch.Tensor, x2: torch.Tensor, fade: torch.Tensor): + fade = self.normalizers["fade"].from_0to1(fade) + return torch.sqrt(fade) * x1 + torch.sqrt((1.0 - fade)) * x2 + + +class Biquad(AbstractModule): + """ + Biquad Filter with Butterworth coefficients + """ + + def __init__(self, sample_rate: int, filter_type: str = "lowpass"): + super().__init__(sample_rate) + + # Only lowpass and highpass are supported for now + 
self.filter_type = filter_type + assert self.filter_type in ["lowpass", "highpass"] + + self.normalizers["freq"] = ParamaterNormalizer( + 20.0, sample_rate // 2, "cutoff freq" + ) + self.normalizers["q"] = ParamaterNormalizer(0.5, 10.0, "q") + + def lowpass_coefficients(self, freq: torch.Tensor, q: torch.Tensor): + """ + Calculate the coefficients for a lowpass filter + """ + freq = torch.tan(torch.pi * freq / self.sample_rate) + norm = 1 / (1 + freq / q + freq * freq) + a0 = freq * freq * norm + a1 = 2 * a0 + a2 = a0 + b1 = 2 * (freq * freq - 1) * norm + b2 = (1 - freq / q + freq * freq) * norm + a_coefs = torch.cat((a0, a1, a2), dim=-1) + b_coefs = torch.cat((torch.ones_like(a0), b1, b2), dim=-1) + return a_coefs, b_coefs + + def highpass_coefficients(self, freq: torch.Tensor, q: torch.Tensor): + """ + Calculate the coefficients for a highpass filter + """ + freq = torch.tan(torch.pi * freq / self.sample_rate) + norm = 1 / (1 + freq / q + freq * freq) + a0 = 1 * norm + a1 = -2 * a0 + a2 = a0 + b1 = 2 * (freq * freq - 1) * norm + b2 = (1 - freq / q + freq * freq) * norm + a_coefs = torch.cat((a0, a1, a2), dim=-1) + b_coefs = torch.cat((torch.ones_like(a0), b1, b2), dim=-1) + return a_coefs, b_coefs + + def forward(self, x: torch.Tensor, freq: torch.Tensor, q: torch.Tensor): + freq = self.normalizers["freq"].from_0to1(freq) + q = self.normalizers["q"].from_0to1(q) + + if self.filter_type == "lowpass": + a_coefs, b_coefs = self.lowpass_coefficients(freq, q) + elif self.filter_type == "highpass": + a_coefs, b_coefs = self.highpass_coefficients(freq, q) + + # Order of coefficients is reversed from documentation + y = torchaudio.functional.lfilter(x, b_coefs, a_coefs, batching=True) + return y + + +class AbstractSynth(torch.nn.Module): + """ + Abstract synthesizer class + """ + + def __init__(self, sample_rate: int, num_samples: int): + super().__init__() + self.sample_rate = sample_rate + self.num_samples = num_samples + + def forward(self, params: Dict[str, torch.Tensor]): + raise NotImplementedError + + def get_param_dict(self): + """ + Returns a dictionary of parameters and their normalizers + """ + param_dict = OrderedDict() + for name, module in self.named_modules(): + if not hasattr(module, "normalizers"): + continue + for param_name, normalizer in module.normalizers.items(): + param_dict[(name, param_name)] = normalizer + return param_dict + + def get_num_params(self): + """ + Returns the number of parameters in the synthesizer + """ + return len(self.get_param_dict()) + + def params_from_dict(self, param_dict: Dict[str, Union[float, torch.Tensor]]): + """ + Converts a dictionary of parameter values to a tensor of parameters that + are normalized between 0 and 1 + """ + normalizers = self.get_param_dict() + params = [] + for key, value in param_dict.items(): + normalizer = normalizers[key] + if isinstance(value, float): + value = torch.tensor([value]) + params.append(normalizer.to_0to1(value)) + + return torch.vstack(params).T + + def damping_from_dict(self, damping_dict: Dict[str, Union[float, torch.Tensor]]): + damping = [] + for name, module in self.named_modules(): + if not hasattr(module, "normalizers"): + continue + for param_name, _ in module.normalizers.items(): + key = (name, param_name) + if key in damping_dict: + damping.append(damping_dict[key]) + else: + damping.append(1.0) + + return torch.tensor(damping).unsqueeze(0) + + @staticmethod + def save_params_json( + path: str, + patch: Dict[Tuple[str, str], float], + damping: Dict[Tuple[str, str], float] = None, + ): + 
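+        """
+        Serialize a patch (and optional damping values) to JSON, flattening
+        each (module, parameter) key into a dotted "module.param" name.
+        """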
param_json = {"preset": {}} + for k, v in patch.items(): + assert "." not in k[0], "Parameter names cannot contain '.'" + assert "." not in k[1], "Parameter names cannot contain '.'" + new_key = k[0] + "." + k[1] + param_json["preset"][new_key] = v + + if damping is not None: + param_json["damping"] = {} + for k, v in damping.items(): + assert "." not in k[0], "Parameter names cannot contain '.'" + assert "." not in k[1], "Parameter names cannot contain '.'" + new_key = k[0] + "." + k[1] + param_json["damping"][new_key] = v + + with open(path, "w") as f: + json.dump(param_json, f, indent=4) + + def load_params_json(self, path: str, as_tensor: bool = True): + with open(path, "r") as f: + param_json = json.load(f) + + patch = {} + for k, v in param_json["preset"].items(): + module_name, param_name = k.split(".") + patch[(module_name, param_name)] = v + + damping = {} + if "damping" in param_json: + for k, v in param_json["damping"].items(): + module_name, param_name = k.split(".") + damping[(module_name, param_name)] = v + + # Normalize into tensors + if as_tensor: + patch = self.params_from_dict(patch) + damping = self.damping_from_dict(damping) + + return patch, damping + + +class SimpleDrumSynth(AbstractSynth): + def __init__( + self, sample_rate, num_samples: int, buffer_noise=False, buffer_size=0 + ): + super().__init__(sample_rate=sample_rate, num_samples=num_samples) + self.amp_env = ExpDecayEnvelope(sample_rate=sample_rate) + self.freq_env = ExpDecayEnvelope(sample_rate=sample_rate) + self.osc = SinusoidalOscillator(sample_rate=sample_rate) + self.tanh = Tanh(sample_rate=sample_rate) + self.gain = Gain(sample_rate=sample_rate) + self.noise_env = ExpDecayEnvelope(sample_rate=sample_rate) + self.noise = WhiteNoise( + sample_rate=sample_rate, buffer_noise=buffer_noise, buffer_size=buffer_size + ) + self.tonal_gain = Gain(sample_rate=sample_rate) + self.noise_gain = Gain(sample_rate=sample_rate) + + def forward(self, params: torch.Tensor, num_samples: Optional[int] = None): + if num_samples is None: + num_samples = self.num_samples + + # Split params -- These should be the same order as the normalizers, + # which is defined by the order returned by get_param_dict() + # TODO: this is easy to mess up, is there a better way? 
+
+    def forward(self, params: torch.Tensor, num_samples: Optional[int] = None):
+        if num_samples is None:
+            num_samples = self.num_samples
+
+        # Split params -- These should be in the same order as the normalizers,
+        # which is defined by the order returned by get_param_dict()
+        # TODO: this is easy to mess up, is there a better way?
+        assert params.shape[-1] == self.get_num_params()
+        (
+            decay,
+            freq_decay,
+            freq,
+            freq_mod,
+            in_gain,
+            out_gain,
+            noise_decay,
+            tonal_gain,
+            noise_gain,
+        ) = torch.split(params, 1, dim=-1)
+
+        freq_env = self.freq_env(num_samples, freq_decay)
+
+        # Generate signal
+        y = self.osc(num_samples, freq, freq_env, freq_mod)
+
+        # Generate envelope
+        env = self.amp_env(num_samples, decay)
+
+        # Apply envelope
+        y = y * env
+
+        # Generate noise
+        noise = self.noise(params.shape[0], num_samples, device=params.device)
+        noise_env = self.noise_env(num_samples, noise_decay)
+        noise = noise * noise_env
+        noise = self.noise_gain(noise, noise_gain)
+
+        # Add noise
+        y = self.tonal_gain(y, tonal_gain)
+        y = y + noise
+
+        y = self.tanh(y, in_gain)
+        y = self.gain(y, out_gain)
+
+        return y
+
+
+class FMDrumSynth(AbstractSynth):
+    """
+    FM drum synthesizer
+    """
+
+    def __init__(
+        self, sample_rate, num_samples: int, buffer_noise=False, buffer_size=0
+    ):
+        super().__init__(sample_rate=sample_rate, num_samples=num_samples)
+        self.amp_env = ExpDecayEnvelope(sample_rate=sample_rate)
+        self.freq_env = ExpDecayEnvelope(sample_rate=sample_rate)
+        self.osc = SinusoidalOscillator(sample_rate=sample_rate)
+        self.mod_amp_env = ExpDecayEnvelope(sample_rate=sample_rate)
+        self.mod_freq_env = ExpDecayEnvelope(sample_rate=sample_rate)
+        self.mod_osc = SinusoidalOscillator(sample_rate=sample_rate)
+        self.mod_gain = Gain(sample_rate, -48.0, 48.0)
+        self.noise_env = ExpDecayEnvelope(sample_rate=sample_rate)
+        self.noise = WhiteNoise(
+            sample_rate=sample_rate, buffer_noise=buffer_noise, buffer_size=buffer_size
+        )
+        self.tonal_gain = Gain(sample_rate=sample_rate)
+        self.noise_gain = Gain(sample_rate=sample_rate)
+        self.tanh = Tanh(sample_rate=sample_rate)
+
+    def forward(self, params: torch.Tensor, num_samples: Optional[int] = None):
+        if num_samples is None:
+            num_samples = self.num_samples
+        # Split params -- These should be in the same order as the normalizers,
+        # which is defined by the order returned by get_param_dict()
+        assert params.shape[-1] == self.get_num_params()
+        (
+            amp_decay,
+            freq_decay,
+            freq,
+            osc_mod,
+            mod_amp_decay,
+            mod_freq_decay,
+            mod_freq,
+            mod_osc_mod,
+            mod_gain,
+            noise_decay,
+            tonal_gain,
+            noise_gain,
+            tanh_gain,
+        ) = torch.split(params, 1, dim=-1)
+
+        mod_freq_env = self.mod_freq_env(num_samples, mod_freq_decay)
+        mod_amp = self.mod_amp_env(num_samples, mod_amp_decay)
+        y_mod = self.mod_osc(num_samples, mod_freq, mod_freq_env, mod_osc_mod)
+        y_mod = y_mod * mod_amp
+        y_mod = self.mod_gain(y_mod, mod_gain)
+
+        freq_env = self.freq_env(num_samples, freq_decay)
+        amp_env = self.amp_env(num_samples, amp_decay)
+        y = self.osc(num_samples, freq, freq_env + y_mod, osc_mod)
+        y = y * amp_env
+
+        # Generate noise
+        noise = self.noise(params.shape[0], num_samples, device=params.device)
+        noise_env = self.noise_env(num_samples, noise_decay)
+        noise = noise * noise_env
+        noise = self.noise_gain(noise, noise_gain)
+
+        # Add noise
+        y = self.tonal_gain(y, tonal_gain)
+        y = y + noise
+
+        y = self.tanh(y, tanh_gain)
+
+        return y
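+
+
+# FM routing note (illustrative): the modulator output y_mod is added to the
+# carrier's per-sample frequency envelope before phase integration, so
+# mod_gain (spanning -48 to +48 dB here) effectively sets the FM index:
+#
+#     synth = FMDrumSynth(sample_rate=44100, num_samples=44100)
+#     params = torch.rand(1, synth.get_num_params())
+#     audio = synth(params)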
+
+
+class Snare808(AbstractSynth):
+    def __init__(
+        self, sample_rate, num_samples: int, buffer_noise=False, buffer_size=0
+    ):
+        super().__init__(sample_rate=sample_rate, num_samples=num_samples)
+        self.osc1 = SinusoidalOscillator(sample_rate=sample_rate)
+        self.osc2 = SinusoidalOscillator(sample_rate=sample_rate)
+        self.freq_env = ExpDecayEnvelope(sample_rate=sample_rate)
+        self.osc1_env = ExpDecayEnvelope(sample_rate=sample_rate)
+        self.osc2_env = ExpDecayEnvelope(sample_rate=sample_rate)
+        self.noise = WhiteNoise(
+            sample_rate=sample_rate, buffer_noise=buffer_noise, buffer_size=buffer_size
+        )
+        self.noise_env = ExpDecayEnvelope(sample_rate=sample_rate)
+        self.noise_filter = Biquad(sample_rate=sample_rate, filter_type="highpass")
+        self.osc1_gain = Gain(sample_rate=sample_rate)
+        self.osc2_gain = Gain(sample_rate=sample_rate)
+        self.noise_gain = Gain(sample_rate=sample_rate)
+        self.tanh = Tanh(sample_rate=sample_rate)
+
+    def forward(self, params: torch.Tensor, num_samples: Optional[int] = None):
+        if num_samples is None:
+            num_samples = self.num_samples
+
+        # Split params -- These should be in the same order as the normalizers,
+        # which is defined by the order returned by get_param_dict()
+        assert params.shape[-1] == self.get_num_params()
+        (
+            osc1_freq,
+            osc1_mod,
+            osc2_freq,
+            osc2_mod,
+            freq_decay,
+            osc1_decay,
+            osc2_decay,
+            noise_decay,
+            noise_freq,
+            noise_q,
+            osc1_gain,
+            osc2_gain,
+            noise_gain,
+            tanh_gain,
+        ) = torch.split(params, 1, dim=-1)
+
+        freq_env = self.freq_env(num_samples, freq_decay)
+
+        # Generate oscillators
+        y1 = self.osc1(num_samples, osc1_freq, freq_env, osc1_mod)
+        y2 = self.osc2(num_samples, osc2_freq, freq_env, osc2_mod)
+
+        # Generate oscillator envelopes
+        env1 = self.osc1_env(num_samples, osc1_decay)
+        env2 = self.osc2_env(num_samples, osc2_decay)
+
+        # Apply envelopes and sum oscillators
+        y = self.osc1_gain(y1, osc1_gain) * env1
+        y = y + self.osc2_gain(y2, osc2_gain) * env2
+
+        # Generate noise
+        noise = self.noise(params.shape[0], num_samples, device=params.device)
+        noise_env = self.noise_env(num_samples, noise_decay)
+        noise = noise * noise_env
+        noise = self.noise_filter(noise, noise_freq, noise_q)
+        noise = self.noise_gain(noise, noise_gain)
+
+        # Add noise and waveshape
+        y = y + noise
+        y = self.tanh(y, tanh_gain)
+
+        return y
+
+
+class SimpleSynth(torch.nn.Module):
+    """
+    Basic synthesizer that generates a sine wave with a static envelope
+    """
+
+    def __init__(self, sample_rate: int = 44100):
+        super().__init__()
+        self.sample_rate = sample_rate
+
+        # Static envelope
+        self.decay_rate = 1.0 - (6.91 / sample_rate)
+        self.env = torch.ones(sample_rate)
+        self.env[1:] = self.decay_rate
+        self.env = torch.cumprod(self.env, dim=0)
+
+    def forward(self, gain: torch.Tensor):
+        f0 = torch.ones(1, self.sample_rate) * 100.0
+        f0 = 2 * torch.pi * f0 / self.sample_rate
+        phase = torch.cumsum(f0, dim=-1)
+        y = torch.sin(phase)
+        y = y * self.env
+
+        return y * gain
diff --git a/timbreremap/tasks.py b/timbreremap/tasks.py
new file mode 100644
index 0000000..8ae72a4
--- /dev/null
+++ b/timbreremap/tasks.py
@@ -0,0 +1,178 @@
+"""
+PyTorch Lightning modules for training models
+"""
+import logging
+import os
+from typing import Optional
+
+import lightning as L
+import torch
+from einops import repeat
+from torchmetrics import Metric
+
+from timbreremap.feature import CascadingFrameExtactor
+from timbreremap.feature import FeatureCollection
+from timbreremap.synth import AbstractSynth
+
+# Setup logging
+logging.basicConfig()
+log = logging.getLogger(__name__)
+log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))
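+
+
+# Training sketch (illustrative only; the model, synth, feature, and loss_fn
+# objects are assumed to be built from configs like cfg/onset_mapping_808.yaml,
+# and the preset path is one of the bundled cfg/presets files):
+#
+#     task = TimbreRemappingTask(
+#         model=model,
+#         synth=synth,
+#         feature=feature,
+#         loss_fn=loss_fn,
+#         preset="cfg/presets/808_snare_1.json",
+#     )
+#     audio = task(onset_features)  # audio for a batch of onset feature vectors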
+
+
+class TimbreRemappingTask(L.LightningModule):
+    """
+    A LightningModule to train a synthesizer timbre remapping model
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,  # Parameter mapping model
+        synth: AbstractSynth,  # Synthesizer
+        feature: torch.nn.Module,  # A feature extractor
+        loss_fn: torch.nn.Module,  # A loss function
+        preset: Optional[str] = None,  # The preset to be modulated (a json file)
+    ):
+        super().__init__()
+        self.model = model
+        self.synth = synth
+        self.feature = feature
+        self.loss_fn = loss_fn
+
+        preset, damping = load_preset_and_damping(synth, preset)
+        self.register_buffer("preset", preset)
+        log.info(f"Loaded preset: {preset}")
+
+        if damping is not None:
+            self.register_buffer("damping", damping)
+            log.info(f"Loaded damping: {damping}")
+
+        # Compute the reference features from the preset
+        reference = self.feature(self.synth(self.preset))
+        self.register_buffer("reference", reference)
+
+        # Set up feature metrics if the feature is a CascadingFrameExtactor
+        feature_metrics = {}
+        pretrain_feature_metrics = {}
+        labels = []
+        if isinstance(self.feature, CascadingFrameExtactor):
+            labels.extend(self.feature.flattened_features)
+        elif isinstance(self.feature, FeatureCollection):
+            for feature in self.feature.features:
+                if isinstance(feature, CascadingFrameExtactor):
+                    labels.extend(feature.flattened_features)
+        else:
+            log.warning(
+                "Feature is not a CascadingFrameExtactor, "
+                "feature metrics will not be calculated"
+            )
+        for label in labels:
+            feature_metrics["_".join(label)] = FeatureErrorMetric()
+            pretrain_feature_metrics["pre" + "_".join(label)] = FeatureErrorMetric()
+        self.feature_metrics = torch.nn.ModuleDict(feature_metrics)
+        self.pretrain_feature_metrics = torch.nn.ModuleDict(pretrain_feature_metrics)
+
+    def forward(self, inputs: torch.Tensor):
+        # Pass the input features through the parameter mapping model
+        param_mod = self.model(inputs)
+
+        # Apply parameter-wise damping if provided
+        if hasattr(self, "damping"):
+            param_mod = param_mod * self.damping
+
+        # Modulate the preset with output of the parameter mapping model
+        assert param_mod.shape[-1] == self.preset.shape[-1]
+        params = self.preset + param_mod
+        params = torch.clip(params, 0.0, 1.0)
+
+        # Pass the output of the parameter mapping model through the synth
+        y = self.synth(params)
+        return y
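+
+    # Note (illustrative): the model predicts offsets in normalized parameter
+    # space, so a preset value of 0.8 plus a predicted offset of 0.5 clips to
+    # 1.0; damping rescales each offset per-parameter before it is applied.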
+
+    def _do_step(self, batch, batch_idx):
+        if len(batch) == 3:
+            inputs, target, norm = batch
+        else:
+            inputs, target = batch
+            norm = 1.0
+
+        y = self(inputs)
+
+        # Calculate features from the synthesized audio
+        y_features = self.feature(y)
+
+        # Calculate the feature difference loss
+        # TODO: add feature norm? Should it come from the dataset?
+        # Can this actually just be something like LayerNorm in the loss module?
+        features = (y_features, self.reference, target, norm)
+        loss = self.loss_fn(y_features, self.reference, target, norm)
+        return loss, features, y
+
+    def training_step(self, batch, batch_idx):
+        loss, _, _ = self._do_step(batch, batch_idx)
+        self.log("train/loss", loss, prog_bar=True, on_epoch=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss, _, _ = self._do_step(batch, batch_idx)
+        self.log("val/loss", loss, prog_bar=True, on_epoch=True)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        loss, features, _ = self._do_step(batch, batch_idx)
+        self.log("test/loss", loss, prog_bar=True)
+
+        y_features, reference, target, norm = features
+        diff = y_features - reference
+
+        for i, (k, v) in enumerate(self.feature_metrics.items()):
+            v.update(diff[..., i], target[..., i])
+            self.log(f"test/{k}", v, prog_bar=True)
+
+        for i, (k, v) in enumerate(self.pretrain_feature_metrics.items()):
+            ref = repeat(reference, "1 n -> b n", b=target.shape[0])
+            v.update(ref[..., i], target[..., i])
+            self.log(f"test/{k}", v, prog_bar=True)
+
+        return loss
+
+
+def load_preset_and_damping(synth: AbstractSynth, preset: Optional[str] = None):
+    # Try loading the preset
+    damping = None
+    if preset is not None:
+        preset, damping = synth.load_params_json(preset)
+    else:
+        log.warning("No preset provided, using random preset")
+        preset = torch.rand(1, synth.get_num_params())
+
+    if preset.shape != (1, synth.get_num_params()):
+        raise ValueError(
+            f"preset must be of shape (1, {synth.get_num_params()}), "
+            f"received {preset.shape}"
+        )
+
+    # Validate the damping shape if provided
+    if damping is not None:
+        if damping.shape != (1, synth.get_num_params()):
+            raise ValueError(
+                f"damping must be of shape (1, {synth.get_num_params()}), "
+                f"received {damping.shape}"
+            )
+
+    return preset, damping
+
+
+class FeatureErrorMetric(Metric):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
+
+    def update(self, diff: torch.Tensor, target: torch.Tensor) -> None:
+        assert diff.shape == target.shape
+        error = torch.abs(diff - target)
+        self.error += torch.sum(error)
+        self.total += diff.shape[0]
+
+    def compute(self) -> torch.Tensor:
+        return self.error.float() / self.total
diff --git a/timbreremap/utils/model.py b/timbreremap/utils/model.py
new file mode 100644
index 0000000..2a3c283
--- /dev/null
+++ b/timbreremap/utils/model.py
@@ -0,0 +1,51 @@
+"""
+Helpful utils for handling pre-trained models
+"""
+import sys
+from typing import List
+from typing import Optional
+from unittest.mock import patch
+
+import lightning as L
+import torch
+from lightning.pytorch.cli import LightningArgumentParser
+from lightning.pytorch.cli import LightningCLI
+
+
+class CustomCLI(LightningCLI):
+    """
+    PyTorch Lightning CLI
+    """
+
+    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        super().add_arguments_to_parser(parser)
+        parser.add_argument("--ckpt_path", type=str, help="Placeholder")
+
+
+def load_model(
+    config: str,
+    ckpt: Optional[str] = None,
+    device: str = "cpu",
+    extra_args: Optional[List[str]] = None,
+    load_data: bool = True,
+):
+    """
+    Load a model from a checkpoint using a config file.
+ """ + args = ["fit", "-c", str(config), "--trainer.accelerator", device] + if extra_args is not None: + args.extend(extra_args) + + datamodule = None + if not load_data: + datamodule = L.LightningDataModule + + with patch.object(sys, "argv", args): + cli = CustomCLI(run=False, datamodule_class=datamodule) + model = cli.model + + if ckpt is not None: + state_dict = torch.load(ckpt, map_location=device)["state_dict"] + model.load_state_dict(state_dict) + + return model, cli