-
Notifications
You must be signed in to change notification settings - Fork 0
/
PruebaAsteroid1.py
34 lines (24 loc) · 1.36 KB
/
PruebaAsteroid1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from torch import optim
from pytorch_lightning import Trainer
# Model architectures; DPRNNTasNet is the one instantiated and trained below.
from asteroid.models import DPRNNTasNet, SuDORMRFImprovedNet
# In this example we use Permutation Invariant Training (PIT) and the SI-SDR loss.
from asteroid.losses import pairwise_neg_sisdr, PITLossWrapper, pairwise_mse
# MiniLibriMix is a tiny version of LibriMix (https://github.com/JorisCos/LibriMix),
# which is a free speech separation dataset.
from asteroid.data import LibriMix
# Asteroid's System is a convenience wrapper for PyTorch-Lightning.
from asteroid.engine import System
# Build the training and validation data loaders for the "sep_clean" task.
# This will automatically download MiniLibriMix from Zenodo on the first run.
train_loader, val_loader = LibriMix.loaders_from_mini(task="sep_clean", batch_size=8)

# Tell DPRNN that we want to separate to 2 sources.
model = DPRNNTasNet(n_src=2)

# PITLossWrapper works with any loss function; here it wraps the pairwise
# negative SI-SDR loss. To train with MSE instead, use:
#   PITLossWrapper(pairwise_mse, pit_from="pw_mtx")
loss = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

optimizer = optim.Adam(model.parameters(), lr=1e-2)

# System bundles the model, optimizer, loss and both loaders so that the
# Lightning Trainer can drive the training loop.
system = System(model, optimizer, loss, train_loader, val_loader)

# Train for 1 epoch using a single GPU. If you're running this on Google Colab,
# be sure to select a GPU runtime (Runtime → Change runtime type → Hardware accelerator).
# NOTE: the former `gpus=1` argument was deprecated in PyTorch Lightning 1.7 and
# removed in 2.0; `accelerator="gpu", devices=1` requests the same single GPU.
trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1)
trainer.fit(system)