forked from sp-uhh/sgmse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalc_metrics.py
69 lines (57 loc) · 2.9 KB
/
calc_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from os.path import join
from glob import glob
from argparse import ArgumentParser
from soundfile import read
from tqdm import tqdm
from pesq import pesq
import pandas as pd
from pystoi import stoi
from sgmse.util.other import energy_ratios, mean_std
if __name__ == '__main__':
    # Score every enhanced file against its clean reference with PESQ, ESTOI,
    # SI-SDR, SI-SIR and SI-SAR. Print mean ± std of each metric and save the
    # per-file scores to <enhanced_dir>/_results.csv.
    from os.path import basename

    parser = ArgumentParser()
    parser.add_argument("--test_dir", type=str, required=True, help='Directory containing the original test data (must have subdirectories clean/ and noisy/)')
    parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data')
    args = parser.parse_args()

    test_dir = args.test_dir
    clean_dir = join(test_dir, "clean/")
    noisy_dir = join(test_dir, "noisy/")
    enhanced_dir = args.enhanced_dir

    data = {"filename": [], "pesq": [], "estoi": [], "si_sdr": [], "si_sir": [], "si_sar": []}
    # PESQ (wideband mode) and ESTOI below are both told the audio is at this
    # rate, so every file read is checked against it rather than trusted.
    sr = 16000

    # Evaluate standard metrics
    noisy_files = sorted(glob(join(noisy_dir, "*.wav")))
    if not noisy_files:
        # Without any files the summary statistics and CSV would be meaningless.
        parser.error("no .wav files found in {}".format(noisy_dir))

    for noisy_file in tqdm(noisy_files):
        # Clean, noisy and enhanced versions of an utterance share a file name.
        filename = basename(noisy_file)
        x, sr_clean = read(join(clean_dir, filename))
        y, sr_noisy = read(noisy_file)
        x_method, sr_enhanced = read(join(enhanced_dir, filename))
        if not (sr_clean == sr_noisy == sr_enhanced == sr):
            raise ValueError(
                "{}: expected {} Hz audio, got clean={} noisy={} enhanced={}".format(
                    filename, sr, sr_clean, sr_noisy, sr_enhanced))

        # Noise signal (noisy minus clean), passed to energy_ratios.
        n = y - x
        # Compute the energy ratios once; items 0, 1 and 2 are SI-SDR,
        # SI-SIR and SI-SAR respectively.
        ratios = energy_ratios(x_method, x, n)

        data["filename"].append(filename)
        data["pesq"].append(pesq(sr, x, x_method, 'wb'))
        data["estoi"].append(stoi(x, x_method, sr, extended=True))
        data["si_sdr"].append(ratios[0])
        data["si_sir"].append(ratios[1])
        data["si_sar"].append(ratios[2])

    # Save results as DataFrame
    df = pd.DataFrame(data)

    # POLQA evaluation - requires POLQA license and server, uncomment at your own peril.
    # This is batch processed for speed reasons and thus runs outside the for loop.
    # NOTE: `basic` and `polqa` are not defined in this script; they must be
    # provided before this section can be enabled.
    # if not basic:
    #     clean_files = sorted(glob('{}/*.wav'.format(clean_dir)))
    #     enhanced_files = sorted(glob('{}/*.wav'.format(enhanced_dir)))
    #     clean_audios = [read(clean_file)[0] for clean_file in clean_files]
    #     enhanced_audios = [read(enhanced_file)[0] for enhanced_file in enhanced_files]
    #     polqa_vals = polqa(clean_audios, enhanced_audios, 16000, save_to=None)
    #     polqa_vals = [val[1] for val in polqa_vals]
    #     # Add POLQA column to DataFrame
    #     df['polqa'] = polqa_vals

    # Print results
    print(enhanced_dir)
    #print("POLQA: {:.2f} ± {:.2f}".format(*mean_std(df["polqa"].to_numpy())))
    print("PESQ: {:.2f} ± {:.2f}".format(*mean_std(df["pesq"].to_numpy())))
    print("ESTOI: {:.2f} ± {:.2f}".format(*mean_std(df["estoi"].to_numpy())))
    print("SI-SDR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sdr"].to_numpy())))
    print("SI-SIR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sir"].to_numpy())))
    print("SI-SAR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sar"].to_numpy())))

    # Save DataFrame as csv file
    df.to_csv(join(enhanced_dir, "_results.csv"), index=False)