-
Notifications
You must be signed in to change notification settings - Fork 2
/
validate.py
48 lines (40 loc) · 1.53 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from tqdm import tqdm
import torch
import pandas as pd
from asteroid.metrics import get_metrics
def validate_sisnr(model, val_loader, device="cuda"):
    """Run one validation pass and accumulate the model's SI-SNR loss.

    Switches ``model`` to eval mode (it is left in eval mode afterwards),
    feeds every ``(mix, clean)`` batch from ``val_loader`` through the model
    on ``device`` with gradients disabled, and scores ``outputs[1]`` against
    ``clean`` via ``model.loss(..., loss_mode='SI-SNR')``.

    Args:
        model: Module that is callable on ``mix`` and exposes
            ``loss(estimate, target, loss_mode=...)``. Its output is indexed
            with ``[1]``; presumably (spectral output, waveform) — TODO confirm.
        val_loader: Iterable of ``(mix, clean)`` tensor batches whose
            ``.dataset`` supports ``len()`` (used to size the progress bar).
        device: Device the batches are moved to before the forward pass.
            Defaults to ``"cuda"``, equivalent to the former ``.cuda()`` call.

    Returns:
        Sum of the per-batch losses (not averaged; divide by the number of
        batches if a mean is wanted). ``0`` if the loader yields nothing.
    """
    model.eval()
    total_val_loss = 0
    with torch.no_grad(), tqdm(total=len(val_loader.dataset)) as pbar:
        for mix, clean in val_loader:
            mix, clean = mix.to(device), clean.to(device)
            # NOTE(review): an earlier comment here read "[B, fft//2, 4803]",
            # presumably the shape of one model output — TODO confirm.
            outputs = model(mix)
            # One .item() per batch: a single device sync, reused for both the
            # running total and the progress-bar text.
            batch_loss = model.loss(outputs[1], clean, loss_mode='SI-SNR').item()
            total_val_loss += batch_loss
            pbar.set_description(f"val_loss: {batch_loss:.5f}")
            pbar.update(mix.size(0))
    return total_val_loss
def validate_pesq(model, val_loader, device="cuda", sample_rate=16000):
    """Run one validation pass and accumulate PESQ using asteroid's metrics.

    Switches ``model`` to eval mode (it is left in eval mode afterwards),
    runs each ``mix`` batch through the model on ``device`` with gradients
    disabled, and scores ``outputs[1]`` against ``clean`` with
    ``get_metrics(..., metrics_list=["pesq"])``.

    Args:
        model: Module callable on ``mix``; its output is indexed with ``[1]``
            to get the estimate (presumably the waveform — TODO confirm).
        val_loader: Iterable of ``(mix, clean)`` tensor batches whose
            ``.dataset`` supports ``len()`` (used to size the progress bar).
        device: Device used for the forward pass. Defaults to ``"cuda"``,
            equivalent to the former ``.cuda()`` call.
        sample_rate: Sample rate in Hz passed to ``get_metrics``. Defaults
            to 16000, the previously hard-coded value.

    Returns:
        Sum over batches of ``utt_metrics["pesq"]`` (not averaged; divide by
        the number of batches if a mean is wanted). ``0`` if the loader
        yields nothing.
    """
    model.eval()
    total_val_loss = 0
    with torch.no_grad(), tqdm(total=len(val_loader.dataset)) as pbar:
        for mix, clean in val_loader:
            # Only the model input is moved to ``device``; get_metrics needs
            # host NumPy arrays, so everything is brought back with .cpu().
            outputs = model(mix.to(device))
            # NOTE(review): a whole batch is passed as one "utterance"; this
            # relies on asteroid treating the leading axis as sources and
            # averaging — confirm the resulting score is the batch mean.
            utt_metrics = get_metrics(
                mix=mix.detach().cpu().numpy(),
                clean=clean.detach().cpu().numpy(),
                estimate=outputs[1].detach().cpu().numpy(),
                sample_rate=sample_rate,
                metrics_list=["pesq"],
            )
            batch_pesq = utt_metrics["pesq"]
            total_val_loss += batch_pesq
            pbar.set_description(f"pesq: {batch_pesq:.5f}")
            pbar.update(mix.size(0))
    return total_val_loss