forked from zqevans/audio-diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_ad_dvae.py
executable file
·365 lines (273 loc) · 12.8 KB
/
train_ad_dvae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
#!/usr/bin/env python3
from prefigure.prefigure import get_all_args, push_wandb_config
from contextlib import contextmanager
from copy import deepcopy
import math
from pathlib import Path
import sys
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from tqdm import trange
import pytorch_lightning as pl
import numpy as np
from pytorch_lightning.utilities.distributed import rank_zero_only
from einops import rearrange
import torchaudio
import wandb
from dataset.dataset import SampleDataset
from diffusion.pqmf import CachedPQMF as PQMF
from audio_diffusion_pytorch import UNet1d
from audio_encoders_pytorch import Encoder1d
from quantizer_pytorch import Quantizer1d
from diffusion.model import ema_update
from viz.viz import embeddings_table, pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
# Define the noise schedule and sampling loop
def get_alphas_sigmas(t):
    """Map timesteps ``t`` to cosine-schedule scaling factors.

    Returns a pair ``(alpha, sigma)``: alpha scales the clean signal and sigma
    scales the noise, with alpha = cos(pi * t / 2) and sigma = sin(pi * t / 2).
    """
    angle = t * math.pi / 2
    alpha = torch.cos(angle)
    sigma = torch.sin(angle)
    return alpha, sigma
def get_crash_schedule(t):
    """Warp uniform timesteps ``t`` onto the "crash" noise schedule.

    The noise level is sigma = sin(pi * t / 2) ** 2 and the signal level is
    alpha = sqrt(1 - sigma ** 2). The returned value is the equivalent
    cosine-schedule timestep, atan2(sigma, alpha) / pi * 2.
    """
    noise_level = torch.sin(t * math.pi / 2) ** 2
    signal_level = (1 - noise_level ** 2) ** 0.5
    return torch.atan2(noise_level, signal_level) / math.pi * 2
def alpha_sigma_to_t(alpha, sigma):
    """Recover a timestep from its (alpha, sigma) scaling factors.

    Computes atan2(sigma, alpha) / pi * 2, so alpha = 1, sigma = 0 maps to
    t = 0 and alpha = 0, sigma = 1 maps to t = 1.
    """
    angle = torch.atan2(sigma, alpha)
    return angle / math.pi * 2
@torch.no_grad()
def sample(model, x, steps, eta, **kwargs):
    """Draw samples from a v-prediction diffusion model, starting from noise.

    Uses a DDIM-style update over ``steps`` evenly spaced timesteps going from
    t = 1 down towards (but not including) t = 0, with scaling factors from
    ``get_alphas_sigmas``.

    Args:
        model: callable invoked as ``model(x, t, **kwargs)``; its output is
            treated as the predicted velocity ``v`` and combined elementwise
            with ``x`` (expected to match ``x``'s shape).
        x: starting noise; ``x.shape[0]`` is used as the batch size.
        steps: number of sampling steps; must be at least 1.
        eta: amount of fresh noise injected per step (0 gives deterministic
            DDIM sampling).
        **kwargs: forwarded unchanged to ``model`` (e.g. conditioning).

    Returns:
        The model's denoised prediction from the final step.

    Raises:
        ValueError: if ``steps`` < 1. Previously this surfaced as an
            UnboundLocalError because no prediction was ever made.
    """
    if steps < 1:
        raise ValueError(f'steps must be at least 1, got {steps}')

    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule
    t = torch.linspace(1, 0, steps + 1)[:-1]
    #t = get_crash_schedule(t)
    alphas, sigmas = get_alphas_sigmas(t)

    # The sampling loop
    for i in trange(steps):

        # Get the model output (v, the predicted velocity)
        with torch.cuda.amp.autocast():
            v = model(x, ts * t[i], **kwargs).float()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < steps - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # If we are on the last timestep, output the denoised image
    return pred
# Weight applied to the quantizer's loss before it is added to the diffusion
# MSE loss in DiffusionDVAE.training_step.
LAMBDA_QUANTIZER = 1e-5
class DiffusionDVAE(pl.LightningModule):
    """Diffusion autoencoder for stereo audio.

    An ``Encoder1d`` maps (optionally PQMF-split) audio to a latent sequence.
    The latents are squashed with ``tanh``, optionally passed through a
    ``Quantizer1d``, linearly upsampled to audio rate, and used as context for
    a ``UNet1d`` trained with the v-objective. EMA copies of the encoder and
    the UNet are kept up to date in ``on_before_zero_grad``.
    """

    def __init__(self, global_args):
        super().__init__()

        self.pqmf_bands = global_args.pqmf_bands

        if self.pqmf_bands > 1:
            self.pqmf = PQMF(2, 70, global_args.pqmf_bands)

        factors = [2, 2, 2, 2, 2]

        # Time-axis ratio between the raw audio and the latents.
        # NOTE(review): the PQMF is assumed to be critically sampled (it yields
        # 2 * pqmf_bands channels, matching the encoder's in_channels), i.e. it
        # shortens the time axis by pqmf_bands before the encoder runs. The
        # ratio must include that factor, or get_context() would upsample the
        # latents to only 1/pqmf_bands of the audio length.
        pqmf_ratio = self.pqmf_bands if self.pqmf_bands > 1 else 1
        self.encoder_ratio = int(np.prod(factors)) * pqmf_ratio

        # The encoder output must have the same channel count that the UNet
        # context and the quantizer are configured with below.
        latent_dim = global_args.latent_dim

        # Base channel width of the encoder. The original code used an
        # undefined name `capacity` here (NameError at construction).
        # NOTE(review): reading it from the config with a default of 32 is an
        # assumption; confirm the intended width.
        capacity = getattr(global_args, 'capacity', 32)

        self.encoder = Encoder1d(
            in_channels=2 * global_args.pqmf_bands,
            channels=capacity,
            multipliers=[1, 2, 4, 8, 16, 32],
            factors=factors,
            out_channels=latent_dim,
        )

        self.encoder_ema = deepcopy(self.encoder)

        self.diffusion = UNet1d(
            in_channels=2,
            channels=128,
            patch_blocks=1,
            patch_factor=1,
            multipliers=[1, 1, 2, 2, 4, 8],
            factors=[2, 2, 4, 5, 5],
            num_blocks=[3, 3, 3, 3, 3],
            attentions=[0, 0, 0, 1, 1, 1],
            attention_heads=8,
            attention_features=64,
            attention_multiplier=2,
            resnet_groups=8,
            kernel_multiplier_downsample=2,
            use_nearest_upsample=False,
            use_skip_scale=True,
            use_context_time=True,
            use_magnitude_channels=False,
            context_channels=[global_args.latent_dim],
        )

        self.diffusion_ema = deepcopy(self.diffusion)

        # Scrambled 1-D Sobol sequence used to draw training timesteps.
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.ema_decay = global_args.ema_decay

        self.num_quantizers = global_args.num_quantizers

        if self.num_quantizers > 0:
            self.quantizer = Quantizer1d(
                channels=global_args.latent_dim,
                num_groups=1,
                codebook_size=global_args.codebook_size,
                num_residuals=self.num_quantizers,
                shared_codebook=False,
                expire_threshold=0.5,
            )

    def configure_optimizers(self):
        """Return an Adam optimizer over the encoder and diffusion UNet parameters."""
        return optim.Adam([*self.encoder.parameters(), *self.diffusion.parameters()], lr=4e-5)

    def get_context(self, latents):
        """Linearly upsample ``latents`` along the last axis by ``encoder_ratio``.

        Returns a one-element list, matching the single entry in the UNet's
        ``context_channels``.
        """
        target_length = int(latents.shape[2] * self.encoder_ratio)
        return [F.interpolate(latents, (target_length,), mode='linear', align_corners=False)]

    def training_step(self, batch, batch_idx):
        """Compute the v-objective diffusion loss (plus weighted quantizer loss)."""
        reals = batch[0]

        encoder_input = reals

        if self.pqmf_bands > 1:
            encoder_input = self.pqmf(reals)

        # Draw uniformly distributed continuous timesteps
        t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)

        #t = get_crash_schedule(t)

        # Calculate the noise schedule parameters for those timesteps
        alphas, sigmas = get_alphas_sigmas(t)

        # Combine the ground truth images and the noise
        alphas = alphas[:, None, None]
        sigmas = sigmas[:, None, None]
        noise = torch.randn_like(reals)
        noised_reals = reals * alphas + noise * sigmas
        targets = noise * alphas - reals * sigmas

        # Compute the model output and the loss.
        with torch.cuda.amp.autocast():
            tokens = self.encoder(encoder_input).float()

        tokens = torch.tanh(tokens)

        if self.num_quantizers > 0:
            tokens, quantizer_info = self.quantizer(tokens)

        with torch.cuda.amp.autocast():
            v = self.diffusion(noised_reals, t, channels_list=self.get_context(tokens))

        mse_loss = F.mse_loss(v, targets)
        loss = mse_loss

        if self.num_quantizers > 0:
            quantizer_loss = LAMBDA_QUANTIZER * quantizer_info["loss"]
            # Out-of-place add: `loss` aliases `mse_loss`, and `+=` would
            # modify that tensor in place, making 'train/mse_loss' report the
            # total loss instead of the MSE term.
            loss = loss + quantizer_loss

        log_dict = {
            'train/loss': loss.detach(),
            'train/mse_loss': mse_loss.detach(),
        }

        if self.num_quantizers > 0:
            log_dict["train/perplexity"] = quantizer_info["perplexity"].sum()
            log_dict["train/n_replaced_codes"] = quantizer_info["replaced_codes"].sum()
            log_dict["train/quantizer_loss"] = quantizer_loss.detach()

            # Log perplexity of each codebook used
            for i, perplexity in enumerate(quantizer_info["perplexity"]):
                log_dict[f"quantizer/train_perplexity_{i}"] = perplexity

            # Log replaced codes of each codebook used
            for i, replaced_codes in enumerate(quantizer_info["replaced_codes"]):
                log_dict[f"quantizer/train_replaced_codes_{i}"] = replaced_codes

        self.log_dict(log_dict, prog_bar=True, on_step=True)
        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        """Update the EMA copies of the diffusion UNet and the encoder.

        Uses a decay of 0.95 for the first 25 epochs, then ``self.ema_decay``.
        """
        decay = 0.95 if self.current_epoch < 25 else self.ema_decay
        ema_update(self.diffusion, self.diffusion_ema, decay)
        ema_update(self.encoder, self.encoder_ema, decay)
class ExceptionCallback(pl.Callback):
    """Lightning callback that prints any exception raised during fitting to stderr."""

    def on_exception(self, trainer, module, err):
        err_name = type(err).__name__
        print(f'{err_name}: {err}', file=sys.stderr)
class DemoCallback(pl.Callback):
    """Periodically reconstructs a demo batch and logs audio and plots to W&B.

    Every ``demo_every`` global steps, encodes a batch of real audio with the
    EMA encoder and samples a reconstruction with the EMA diffusion UNet. Both
    the reconstruction and the real audio are written as WAV files in the
    working directory, and audio, embeddings and spectrogram images are logged
    through ``trainer.logger.experiment``.
    """

    def __init__(self, demo_dl, global_args):
        super().__init__()
        self.demo_every = global_args.demo_every
        self.demo_samples = global_args.sample_size
        self.demo_steps = global_args.demo_steps
        # Keep the loader itself so the iterator can be restarted once exhausted.
        self._demo_loader = demo_dl
        self.demo_dl = iter(demo_dl)
        self.sample_rate = global_args.sample_rate
        self.pqmf_bands = global_args.pqmf_bands
        self.quantized = global_args.num_quantizers > 0
        # Global step of the most recent demo. on_train_batch_end runs once per
        # batch, and several batches can share one global_step (e.g. with
        # gradient accumulation), so this must persist across calls.
        self.last_demo_step = -1

        if self.pqmf_bands > 1:
            self.pqmf = PQMF(2, 70, global_args.pqmf_bands)

    def _next_demo_batch(self):
        """Return the next demo batch, restarting the loader when it runs out."""
        try:
            return next(self.demo_dl)
        except StopIteration:
            self.demo_dl = iter(self._demo_loader)
            return next(self.demo_dl)

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
        """Run a demo when ``(global_step - 1) % demo_every == 0``, at most once per step.

        Failures while saving or logging are printed to stderr rather than
        interrupting training.
        """
        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
            return

        self.last_demo_step = trainer.global_step

        demo_reals, _ = self._next_demo_batch()

        encoder_input = demo_reals

        if self.pqmf_bands > 1:
            encoder_input = self.pqmf(demo_reals)

        encoder_input = encoder_input.to(module.device)

        demo_reals = demo_reals.to(module.device)

        noise = torch.randn([demo_reals.shape[0], 2, self.demo_samples]).to(module.device)

        with torch.no_grad():
            tokens = module.encoder_ema(encoder_input)
            tokens = torch.tanh(tokens)

            if self.quantized:
                tokens, quantizer_info = module.quantizer(tokens)

        fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0, channels_list=module.get_context(tokens))

        # Put the demos together: concatenate the batch along time per channel.
        fakes = rearrange(fakes, 'b d n -> d (b n)')
        demo_reals = rearrange(demo_reals, 'b d n -> d (b n)')

        try:
            log_dict = {}

            filename = f'recon_{trainer.global_step:08}.wav'
            fakes = fakes.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            torchaudio.save(filename, fakes, self.sample_rate)

            reals_filename = f'reals_{trainer.global_step:08}.wav'
            demo_reals = demo_reals.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            torchaudio.save(reals_filename, demo_reals, self.sample_rate)

            log_dict['recon'] = wandb.Audio(filename,
                                            sample_rate=self.sample_rate,
                                            caption='Reconstructed')
            log_dict['real'] = wandb.Audio(reals_filename,
                                           sample_rate=self.sample_rate,
                                           caption='Real')

            log_dict['embeddings'] = embeddings_table(tokens)

            log_dict['embeddings_3dpca'] = pca_point_cloud(tokens)
            log_dict['embeddings_spec'] = wandb.Image(tokens_spectrogram_image(tokens))

            log_dict['real_melspec_left'] = wandb.Image(audio_spectrogram_image(demo_reals))
            log_dict['recon_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes))

            trainer.logger.experiment.log(log_dict, step=trainer.global_step)
        except Exception as e:
            # Best-effort logging: report the problem but keep training.
            print(f'{type(e).__name__}: {e}', file=sys.stderr)
def main():
    """Parse the config, build data loaders, callbacks and model, then train."""
    args = get_all_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    torch.manual_seed(args.seed)

    train_set = SampleDataset([args.training_dir], args)
    # DataLoader raises ValueError for persistent_workers=True with
    # num_workers == 0, so only request persistent workers when they exist.
    train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True,
                               num_workers=args.num_workers,
                               persistent_workers=args.num_workers > 0,
                               pin_memory=True)
    wandb_logger = pl.loggers.WandbLogger(project=args.name)

    demo_dl = data.DataLoader(train_set, args.num_demos, num_workers=args.num_workers, shuffle=True)

    exc_callback = ExceptionCallback()
    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1)
    demo_callback = DemoCallback(demo_dl, args)

    diffusion_model = DiffusionDVAE(args)

    wandb_logger.watch(diffusion_model)
    push_wandb_config(wandb_logger, args)

    diffusion_trainer = pl.Trainer(
        devices=args.num_gpus,
        accelerator="gpu",
        num_nodes=args.num_nodes,
        strategy='ddp',
        #precision=16,
        accumulate_grad_batches=args.accum_batches,
        callbacks=[ckpt_callback, demo_callback, exc_callback],
        logger=wandb_logger,
        log_every_n_steps=1,
        max_epochs=10000000,
    )

    diffusion_trainer.fit(diffusion_model, train_dl, ckpt_path=args.ckpt_path)
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()