metrics.py
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, recall_score, balanced_accuracy_score


class EvalMetrics(object):
    def __init__(self) -> None:
        self._subjects = None

    @property
    def subjects(self):
        return self._subjects

    @subjects.setter  # set up the test set
    def subjects(self, subjects):
        """
        Args:
            subjects: list, [test_users, test_items, test_scores, test_y]
        """
        assert isinstance(subjects, list)
        test_users, test_items, test_scores, test_y = (
            subjects[0],
            subjects[1],
            subjects[2],
            subjects[3],
        )
        test_set = pd.DataFrame(
            {"user": test_users, "item": test_items, "score": test_scores, "y": test_y}
        )
        self._subjects = test_set

    def cal_acc(self):
        """Compute accuracy."""
        test_set = self._subjects
        test_set["pred"] = 0
        # Where test_set["score"] >= 0.5, set the pred column to 1 so the
        # predictions are binary and comparable with the 0/1 labels in y.
        test_set.loc[test_set["score"] >= 0.5, "pred"] = 1
        pred = test_set["pred"].values
        true = test_set["y"].values
        return accuracy_score(true, pred)

    def cal_recall(self):
        """Compute recall."""
        test_set = self._subjects
        test_set["pred"] = 0
        test_set.loc[test_set["score"] >= 0.5, "pred"] = 1
        pred = test_set["pred"].values
        true = test_set["y"].values
        return recall_score(true, pred)

    def cal_f1(self):
        """Compute weighted F1 score."""
        test_set = self._subjects
        test_set["pred"] = 0
        test_set.loc[test_set["score"] >= 0.5, "pred"] = 1
        pred = test_set["pred"].values
        true = test_set["y"].values
        return f1_score(true, pred, average="weighted")

    def cal_balanced_acc(self):
        """Compute balanced accuracy."""
        test_set = self._subjects
        test_set["pred"] = 0
        test_set.loc[test_set["score"] >= 0.5, "pred"] = 1
        pred = test_set["pred"].values
        true = test_set["y"].values
        return balanced_accuracy_score(true, pred)
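

# A minimal usage sketch, assuming a model that produces probability-like
# scores in [0, 1]. The user/item/score/label lists below are made-up
# illustration data, not part of the original module.
if __name__ == "__main__":
    metrics = EvalMetrics()
    metrics.subjects = [
        [1, 1, 2, 2],          # test_users
        [10, 20, 10, 30],      # test_items
        [0.9, 0.2, 0.6, 0.4],  # test_scores (model outputs)
        [1, 0, 1, 1],          # test_y (ground-truth binary labels)
    ]
    print("accuracy:", metrics.cal_acc())
    print("recall:", metrics.cal_recall())
    print("weighted F1:", metrics.cal_f1())
    print("balanced accuracy:", metrics.cal_balanced_acc())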