Skip to content

Commit

Permalink
Chunked inference for codec (#22)
Browse files Browse the repository at this point in the history
* Adding some devtools.

* Adding delay calculation

* Chunked inference for codec.

* Version bump

* Removing prod.yml, updating to recent main.

* Turning padding off only when chunking codes.

* Updating README, removing unused things.

* Missed a padding.

* Adding some checks to make sure pads are the same.

* Factoring out latent dim, backwards compatible.

* Adding latent dim, and the 44khz 16kbps model config.

* Ran pre-commit.

* Chunked vs unchunked inference.

* Fixing padding stuff.

* n quantizers back in encode

* don't load unsupported versions

* correct docstring

* bitrate config + 16kbps models

* update audiotools dep

* fix argbind issue

* minor correction

* bump version

* change model path

* update audiotools deps

---------

Co-authored-by: prem <[email protected]>
Co-authored-by: Ishaan Kumar <[email protected]>
  • Loading branch information
3 people authored Jul 20, 2023
1 parent 0202b4a commit c7cfc5d
Show file tree
Hide file tree
Showing 17 changed files with 703 additions and 337 deletions.
13 changes: 13 additions & 0 deletions Dockerfile.dev
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
ARG IMAGE=pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
ARG GITHUB_TOKEN=none

FROM $IMAGE

RUN echo machine github.com login ${GITHUB_TOKEN} > ~/.netrc

COPY requirements.txt /requirements.txt

RUN apt update && apt install -y git

# install the package
RUN pip install --upgrade -r /requirements.txt
61 changes: 46 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,42 @@ for more options.
### Programmatic Usage
```py
import dac
from dac.utils import load_model
from dac.model import DAC

from dac.utils.encode import process as encode
from dac.utils.decode import process as decode

from audiotools import AudioSignal

# Init an empty model
model = DAC()
# Download a model
model_path = dac.utils.download(model_type="44khz")
model = dac.DAC.load(model_path)

# Load compatible pre-trained model
model = load_model(tag="latest", model_type="44khz")
model.eval()
model.to('cuda')

# Load audio signal file
signal = AudioSignal('input.wav')

# Encode audio signal
encoded_out = encode(signal, 'cuda', model)
# Encode audio signal as one long file
# (may run out of GPU memory on long files)
signal.to(model.device)

x = model.preprocess(signal.audio_data, signal.sample_rate)
z, codes, latents, _, _ = model.encode(x)

# Decode audio signal
recon = decode(encoded_out, 'cuda', model, preserve_sample_rate=True)
y = model.decode(z)

# Alternatively, use the `compress` and `decompress` functions
# to compress long files.

signal = signal.cpu()
x = model.compress(signal)

# Save and load to and from disk
x.save("compressed.dac")
x = dac.DACFile.load("compressed.dac")

# Decompress it back to an AudioSignal
y = model.decompress(x)

# Write to file
recon.write('recon.wav')
y.write('output.wav')
```

### Docker image
Expand Down Expand Up @@ -131,6 +140,28 @@ Please install the correct dependencies
pip install -e ".[dev]"
```
## Environment setup
We have provided a Dockerfile and docker compose setup that makes running experiments easy.
To build the docker image do:
```
docker compose build
```
Then, to launch a container, do:
```
docker compose run -p 8888:8888 -p 6006:6006 dev
```
The port arguments (`-p`) are optional, but useful if you want to launch a Jupyter and Tensorboard instances within the container. The
default password for Jupyter is `password`, and the current directory
is mounted to `/u/home/src`, which also becomes the working directory.
Then, run your training command.
### Single GPU training
```
Expand Down
124 changes: 124 additions & 0 deletions conf/final/44khz-16kbps.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Model setup
DAC.sample_rate: 44100
DAC.encoder_dim: 64
DAC.encoder_rates: [2, 4, 8, 8]
DAC.latent_dim: 128
DAC.decoder_dim: 1536
DAC.decoder_rates: [8, 8, 4, 2]

# Quantization
DAC.n_codebooks: 18 # Max bitrate of 16kbps
DAC.codebook_size: 1024
DAC.codebook_dim: 8
DAC.quantizer_dropout: 0.5

# Discriminator
Discriminator.sample_rate: 44100
Discriminator.rates: []
Discriminator.periods: [2, 3, 5, 7, 11]
Discriminator.fft_sizes: [2048, 1024, 512]
Discriminator.bands:
- [0.0, 0.1]
- [0.1, 0.25]
- [0.25, 0.5]
- [0.5, 0.75]
- [0.75, 1.0]

# Optimization
AdamW.betas: [0.8, 0.99]
AdamW.lr: 0.0001
ExponentialLR.gamma: 0.999996

amp: false
val_batch_size: 100
device: cuda
num_iters: 400000
save_iters: [10000, 50000, 100000, 200000]
valid_freq: 1000
sample_freq: 10000
num_workers: 32
val_idx: [0, 1, 2, 3, 4, 5, 6, 7]
seed: 0
lambdas:
mel/loss: 15.0
adv/feat_loss: 2.0
adv/gen_loss: 1.0
vq/commitment_loss: 0.25
vq/codebook_loss: 1.0

VolumeNorm.db: [const, -16]

# Transforms
build_transform.preprocess:
- Identity
build_transform.augment_prob: 0.0
build_transform.augment:
- Identity
build_transform.postprocess:
- VolumeNorm
- RescaleAudio
- ShiftPhase

# Loss setup
MultiScaleSTFTLoss.window_lengths: [2048, 512]
MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320]
MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048]
MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0]
MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null]
MelSpectrogramLoss.pow: 1.0
MelSpectrogramLoss.clamp_eps: 1.0e-5
MelSpectrogramLoss.mag_weight: 0.0

# Data
batch_size: 72
train/AudioDataset.duration: 0.38
train/AudioDataset.n_examples: 10000000

val/AudioDataset.duration: 5.0
val/build_transform.augment_prob: 1.0
val/AudioDataset.n_examples: 250

test/AudioDataset.duration: 10.0
test/build_transform.augment_prob: 1.0
test/AudioDataset.n_examples: 1000

AudioLoader.shuffle: true
AudioDataset.without_replacement: true

train/build_dataset.folders:
speech_fb:
- /data/daps/train
speech_hq:
- /data/vctk
- /data/vocalset
- /data/read_speech
- /data/french_speech
speech_uq:
- /data/emotional_speech/
- /data/common_voice/
- /data/german_speech/
- /data/russian_speech/
- /data/spanish_speech/
music_hq:
- /data/musdb/train
music_uq:
- /data/jamendo
general:
- /data/audioset/data/unbalanced_train_segments/
- /data/audioset/data/balanced_train_segments/

val/build_dataset.folders:
speech_hq:
- /data/daps/val
music_hq:
- /data/musdb/test
general:
- /data/audioset/data/eval_segments/

test/build_dataset.folders:
speech_hq:
- /data/daps/test
music_hq:
- /data/musdb/test
general:
- /data/audioset/data/eval_segments/
5 changes: 4 additions & 1 deletion dac/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.5"
__version__ = "1.0.0"

# preserved here for legacy reasons
__model_version__ = "latest"
Expand All @@ -11,3 +11,6 @@

from . import nn
from . import model
from . import utils
from .model import DAC
from .model import DACFile
2 changes: 1 addition & 1 deletion dac/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import argbind

from dac.utils import ensure_default_model as download
from dac.utils import download
from dac.utils.decode import decode
from dac.utils.encode import encode

Expand Down
1 change: 1 addition & 0 deletions dac/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base import CodecMixin
from .base import DACFile
from .dac import DAC
from .discriminator import Discriminator
Loading

0 comments on commit c7cfc5d

Please sign in to comment.