import logging
from typing import List, Tuple

import clingo
import torch
from torch.utils.data import DataLoader

from analysis import MultiClassAccuracyMeter, JaccardScoreMeter
from common import CUBDNDataItem
from rule_learner import DNFBasedClassifier
from utils import get_dnf_classifier_x_and_y

log = logging.getLogger()

def dnf_eval(
    model: DNFBasedClassifier,
    use_cuda: bool,
    data_loader: DataLoader,
    use_jaccard_meter: bool = False,
    jaccard_threshold: float = 0.0,
    do_logging: bool = False,
):
    """Evaluate a DNF-based classifier over a data loader and return its
    average performance (Jaccard score or multi-class accuracy)."""
    model.eval()
    performance_meter = (
        JaccardScoreMeter() if use_jaccard_meter else MultiClassAccuracyMeter()
    )

    for i, data in enumerate(data_loader):
        # Per-batch meter, used only for logging intermediate performance
        iter_perf_meter = (
            JaccardScoreMeter()
            if use_jaccard_meter
            else MultiClassAccuracyMeter()
        )
        with torch.no_grad():
            x, y = get_dnf_classifier_x_and_y(data, use_cuda)
            y_hat = model(x)
            if use_jaccard_meter:
                # Binarise the tanh-activated output before computing Jaccard
                y_hat = (torch.tanh(y_hat) > jaccard_threshold).long()
            iter_perf_meter.update(y_hat, y)
            performance_meter.update(y_hat, y)

        if do_logging:
            log.info(
                "[%3d] Test avg perf: %.3f"
                % (i + 1, iter_perf_meter.get_average())
            )

    if do_logging:
        log.info(
            "Overall Test avg perf: %.3f" % performance_meter.get_average()
        )
    return performance_meter.get_average()

def asp_eval(
    test_data: List[CUBDNDataItem], rules: List[str], debug: bool = False
) -> Tuple[float, float]:
    """Evaluate learned ASP rules with clingo and return (accuracy,
    average Jaccard score) over the test data."""
    total_sample_count = 0
    correct_count = 0
    jaccard_scores = []

    for d in test_data:
        # Encode the sample's present attributes as ASP facts, append the
        # learned rules, and only show the predicted class atoms
        asp_base = []
        for i, a in enumerate(d.attr_present_label):
            if a == 1:
                asp_base.append(f"has_attr_{i}.")
        asp_base += rules
        asp_base.append("#show class/1.")

        ctl = clingo.Control(["--warn=none"])
        ctl.add("base", [], " ".join(asp_base))
        ctl.ground([("base", [])])

        with ctl.solve(yield_=True) as handle:  # type: ignore
            all_answer_sets = [str(a) for a in handle]

        target_class = f"class({d.label - 1})"
        if debug:
            log.info(f"y: {target_class} AS: {all_answer_sets}")

        if len(all_answer_sets) != 1:
            # No answer set or multiple answer sets; should not happen
            log.warning("No model or multiple answer sets when evaluating rules.")
            continue

        output_classes = all_answer_sets[0].split(" ")
        output_classes_set = set(output_classes)
        target_class_set = {target_class}
        jacc = len(output_classes_set & target_class_set) / len(
            output_classes_set | target_class_set
        )
        jaccard_scores.append(jacc)

        if len(output_classes) == 1 and target_class in output_classes:
            correct_count += 1
        total_sample_count += 1

    accuracy = correct_count / total_sample_count
    avg_jacc_score = sum(jaccard_scores) / len(jaccard_scores)
    return accuracy, avg_jacc_score
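

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The objects referenced below (`model`,
# `test_loader`, `test_items`, `learned_rules`) are hypothetical placeholders
# for artefacts produced elsewhere in the pipeline; only the call signatures
# come from this file.
#
#     acc = dnf_eval(
#         model,
#         use_cuda=torch.cuda.is_available(),
#         data_loader=test_loader,
#         do_logging=True,
#     )
#     accuracy, avg_jaccard = asp_eval(test_items, learned_rules, debug=False)
# ---------------------------------------------------------------------------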