# infer.py -- forked from Xiaobin-Rong/gtcrn
import os
import torch
import soundfile as sf
from gtcrn import GTCRN
## load model
device = torch.device("cpu")
model = GTCRN().eval()
# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints that come from a trusted source.
ckpt = torch.load(os.path.join('checkpoints', 'model_trained_on_dns3.tar'), map_location=device)
model.load_state_dict(ckpt['model'])

## load data
mix, fs = sf.read(os.path.join('test_wavs', 'mix.wav'), dtype='float32')
# Explicit raises instead of `assert`, which is stripped under `python -O`.
if fs != 16000:
    raise ValueError(f"expected a 16000 Hz input file, got {fs} Hz")
# The STFT output below is given exactly one extra (batch) axis before the
# model call, so the waveform itself must be 1-D (single channel).
if mix.ndim != 1:
    raise ValueError(f"expected a mono waveform, got an array of shape {mix.shape}")

## inference
# Square-root Hann window, shared by the analysis and synthesis transforms.
window = torch.hann_window(512).pow(0.5)
# `return_complex=False` is deprecated in torch.stft; requesting the complex
# spectrum and viewing it as real yields the same trailing (real, imag) axis.
noisy_spec = torch.view_as_real(
    torch.stft(torch.from_numpy(mix), n_fft=512, hop_length=256, win_length=512,
               window=window, return_complex=True)
)
with torch.no_grad():
    # [None] adds the batch axis for the model; [0] removes it again.
    enh_spec = model(noisy_spec[None])[0]
# torch.istft only accepts complex input on PyTorch >= 2.0, so convert the
# model's (real, imag) output back to complex. `.contiguous()` is needed
# because view_as_complex requires a unit stride on the last axis.
enh = torch.istft(torch.view_as_complex(enh_spec.contiguous()),
                  n_fft=512, hop_length=256, win_length=512, window=window)

## save enhanced wav
sf.write(os.path.join('test_wavs', 'enh.wav'), enh.detach().cpu().numpy(), fs)