-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy patheval.py
123 lines (100 loc) · 3.64 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import pickle
import os
import argparse
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from tqdm import tqdm
import cogmen
# Module-level logger obtained from the project's cogmen.utils helper.
# NOTE(review): not referenced anywhere else in this file; results are
# reported with print() in main().
log = cogmen.utils.get_logger()
def load_pkl(file):
    """Unpickle and return the single object stored in the file at *file*.

    NOTE: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(file, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj
def _load_checkpoint(path):
    """Load a checkpoint written with ``torch.save`` from *path*.

    The checkpoint holds pickled Python objects (an args object and the model
    that main() calls), not only tensors, so it cannot be read with
    ``weights_only=True`` -- the ``torch.load`` default since PyTorch 2.6.
    NOTE: this unpickles the file and can execute arbitrary code; only load
    trusted checkpoints.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no ``weights_only`` keyword; its default there
        # is already full unpickling.
        return torch.load(path)


def _predict(model, testset, device, desc="test"):
    """Run *model* over every sample of *testset* with gradients disabled.

    Returns ``(golds, preds)``: a list of each sample's ``"label_tensor"``
    (left where the dataset put it) and a list of the model outputs moved
    back to the CPU.
    """
    golds = []
    preds = []
    with torch.no_grad():
        for idx in tqdm(range(len(testset)), desc=desc):
            sample = testset[idx]
            # Keep the label before the fields are moved to *device* below.
            golds.append(sample["label_tensor"])
            # Every field except the raw utterance texts is moved to the
            # inference device (presumably all tensors -- they support .to()).
            for key, value in sample.items():
                if key != "utterance_texts":
                    sample[key] = value.to(device)
            y_hat = model(sample)
            preds.append(y_hat.detach().to("cpu"))
    return golds, preds


def _per_emotion_f1(golds, preds):
    """Return a dict of weighted F1 per MOSEI emotion column (columns 0-5)."""
    emotions = ("happy", "sad", "anger", "surprise", "disgust", "fear")
    return {
        name: metrics.f1_score(golds[:, col], preds[:, col], average="weighted")
        for col, name in enumerate(emotions)
    }


def main(args):
    """Evaluate the best-dev-F1 checkpoint for ``args.dataset`` on its test split.

    Reads ``<data_dir_path>/<dataset>/data_<dataset>.pkl`` and
    ``model_checkpoints/<dataset>_best_dev_f1_model_<modalities>.pt``, then
    prints a classification report followed by the weighted F1 score (a dict
    of per-emotion scores for MOSEI multilabel).

    Args:
        args: parsed CLI arguments; reads ``dataset``, ``modalities`` and,
            when present and non-empty, ``data_dir_path`` (falls back to
            ``"data"``). The device, dataset and emotion settings used for
            inference come from the args stored inside the checkpoint.
    """
    data_dir = getattr(args, "data_dir_path", None) or "data"
    data = load_pkl(os.path.join(data_dir, args.dataset, f"data_{args.dataset}.pkl"))
    model_dict = _load_checkpoint(
        os.path.join(
            "model_checkpoints",
            f"{args.dataset}_best_dev_f1_model_{args.modalities}.pt",
        )
    )
    stored_args = model_dict["args"]
    # Despite the key name, this entry is called directly as the model, so it
    # is expected to hold the model object rather than a weights dict.
    model = model_dict["state_dict"]
    if isinstance(model, torch.nn.Module):
        # Inference mode: disables dropout and uses running batch-norm stats,
        # so predictions are deterministic.
        model.eval()

    testset = cogmen.Dataset(data["test"], stored_args)
    golds, preds = _predict(model, testset, stored_args.device, desc="test")

    multilabel = (
        stored_args.dataset == "mosei" and stored_args.emotion == "multilabel"
    )
    # Multilabel outputs are stacked along the first axis, single-label
    # outputs concatenated along the last one.
    cat_dim = 0 if multilabel else -1
    golds = torch.cat(golds, dim=cat_dim).numpy()
    preds = torch.cat(preds, dim=cat_dim).numpy()

    print(metrics.classification_report(golds, preds, digits=4))
    if multilabel:
        f1 = _per_emotion_f1(golds, preds)
    else:
        f1 = metrics.f1_score(golds, preds, average="weighted")
    print(f"F1 Score: {f1}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="eval.py")
    # Required, so no default: argparse would never use one.
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=["iemocap", "iemocap_4", "mosei"],
        help="Dataset name.",
    )
    parser.add_argument(
        "--data_dir_path", type=str, help="Dataset directory path", default="./data"
    )
    # NOTE(review): main() moves inputs to the device stored in the
    # checkpoint's args, so this flag currently has no effect.
    parser.add_argument("--device", type=str, default="cpu", help="Computing device.")
    # Modalities affect:
    #   -> dimensions of input vectors in dataset.py
    #   -> number of heads in transformer_conv in seqcontext.py
    parser.add_argument(
        "--modalities",
        type=str,
        default="atv",
        choices=["a", "at", "atv"],
        help="Modalities",
    )
    # NOTE(review): main() reads the emotion setting from the checkpoint's
    # stored args, so this flag currently has no effect.
    parser.add_argument(
        "--emotion", type=str, default=None, help="emotion class for mosei"
    )
    args = parser.parse_args()
    main(args)