# GatewayExtractionBenchmark.py
# add parent dir to sys path for import of modules
import os
import sys

# walk up the directory tree to find the project root dir
parent_dir = os.getcwd()
while not os.path.exists(os.path.join(parent_dir, "README.md")):
    parent_dir = os.path.abspath(os.path.join(parent_dir, os.pardir))
sys.path.insert(0, parent_dir)
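# e.g. when the script is run from a subdirectory such as relation_approaches/ (a guess based on the
# imports below), the loop climbs one directory at a time until it reaches the repository root
# (identified here by the presence of a README.md), so that top-level modules like 'utils' and
# 'PetReader' become importable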
from typing import Dict, List
from copy import deepcopy
import logging
import json
from abc import abstractmethod
from petreader.labels import *
from relation_approaches.AbstractClassificationBenchmark import AbstractClassificationBenchmark
from relation_approaches.GatewayExtractor import GatewayExtractor, Gateway
from relation_approaches.RelationClassifier import NeuralRelationClassifierEnsemble
from relation_approaches.activity_relation_data_preparation import DOC_BLACK_LIST
from relation_approaches.activity_relation_dataset_preparation import TEST_DOCS
from relation_approaches.RelationClassificationBenchmark import get_static_args
from relation_approaches import metrics
from utils import ROOT_DIR, save_as_pickle, flatten_list, config
from PetReader import pet_reader
from labels import *
logger = logging.getLogger('Gateway Extraction Benchmark')

GATEWAY_TYPES = [XOR_GATEWAY, AND_GATEWAY, NO_GATEWAY_RELATIONS]


class GatewayExtractionBenchmark(AbstractClassificationBenchmark):
    """
    Creates and evaluates gateway extractions produced by a GatewayExtractor instance
    """

    def __init__(self, approach_name: str, gateway_extractor: GatewayExtractor, output_folder: str = None,
                 round_digits: int = 2):
        self.gateway_extractor = gateway_extractor
        # prepare output folder
        if not output_folder:
            output_folder = os.path.join(ROOT_DIR,
                                         f"data/results_relation_approaches/gateway_extraction/{approach_name}")
        AbstractClassificationBenchmark.__init__(self, GATEWAY_TYPES, approach_name, output_folder, round_digits)
    def evaluate_documents(self, doc_names: List[str] = None):
        """
        evaluate a list of documents with the gateway extractor
        :param doc_names: document names; if None, evaluate all documents
        :return:
        """
        if not doc_names:
            doc_names = pet_reader.document_names
        doc_names = [d for d in doc_names if d not in DOC_BLACK_LIST]

        logger.info(f"Create predictions for {len(doc_names)} documents")
        gateway_extractions = {doc_name: self.gateway_extractor.extract_document_gateways(doc_name, i + 1)
                               for i, doc_name in enumerate(doc_names)}

        logger.info("Compute metrics of all labels in all documents")
        all_doc_metrics = self.compute_document_label_metrics(gateway_extractions)

        logger.info("Average label metrics")
        label_avg_metrics = self.average_label_wise(all_doc_metrics)

        logger.info("Average metrics over all documents")
        overall_avg_metrics = metrics.average_metrics([m for label, m in label_avg_metrics.items()],
                                                      self.round_digits, self.support_weighted_average)

        logger.info(f"Write results & predictions to {self.output_folder}")
        self.write_results(gateway_extractions)
        self.write_metrics(all_doc_metrics, label_avg_metrics, overall_avg_metrics)
    def compute_document_label_metrics(self, gateway_extractions: Dict[str, List[Gateway]]) -> Dict[str, Dict]:
        """
        Compute metrics per gateway label/type and document
        :param gateway_extractions: dictionary with extracted gateways per document
        :return: dictionary with structure {doc-name: {label: {metric: value}}}
        """
        all_doc_metrics = {}
        for i, doc_name in enumerate(gateway_extractions.keys()):
            logger.info(f"Compute metrics for {doc_name} ({i + 1}/{len(gateway_extractions.keys())})")
            doc_metrics = {label: self.evaluate_gateway_extractions(doc_name, gateway_extractions[doc_name],
                                                                    label=label) for label in self.labels}
            all_doc_metrics[doc_name] = doc_metrics
        return all_doc_metrics
    @abstractmethod
    def evaluate_gateway_extractions(self, doc_name, gateway_extractions: List[Gateway], label: str) -> Dict:
        """
        compute precision, recall and f1 score for given extracted gateways of a document
        :param doc_name: document name
        :param gateway_extractions: extracted gateways as a list of Gateway objects
        :param label: filter gateway extractions and gold standard to this label
        :return: dictionary with metrics
        """
        pass
    def write_results(self, gateway_extractions: Dict[str, List[Gateway]]) -> None:
        """
        write single extractions to pickle, json and txt files
        :param gateway_extractions: dictionary of gateway extractions document-wise
        :return:
        """
        save_as_pickle(gateway_extractions, os.path.join(self.output_folder, "predictions.pkl"))

        with open(os.path.join(self.output_folder, "predictions.json"), 'w') as file:
            json.dump({doc_name: [g.to_json() for g in gateways]
                       for doc_name, gateways in gateway_extractions.items()},
                      file, indent=4)

        with open(os.path.join(self.output_folder, "predictions.txt"), 'w') as file:
            for doc_name, gateways in gateway_extractions.items():
                file.write(f" {doc_name} ".center(100, '-') + "\n")
                for g in gateways:
                    file.write(str(g) + "\n")
                file.write("\n" * 3)


class SimpleGatewayTypeAndNumberBenchmark(GatewayExtractionBenchmark):
    """
    Evaluates gateway extractions produced by a GatewayExtractor instance by just checking whether a gateway of the
    same type is contained in the PET gold standard
    Notes:
        - gateway information such as location or tokens is not available when extracting with 'GatewayExtractor',
          which is based on activity pair relations
        - gateways that are connected by a same-gateway relation in PET are already extracted as one gateway,
          i.e. PET contains more gateways -> handled by subtracting the number of same-gateway relations
    """
    def evaluate_gateway_extractions(self, doc_name, gateway_extractions: List[Gateway], label: str) -> Dict:
        """
        compute precision, recall and f1 score for given extracted gateways of a document
        :param doc_name: document name
        :param gateway_extractions: extracted gateways as a list of Gateway objects
        :param label: filter gateway extractions and gold standard to this label
        :return: dictionary with metrics
        """
        pred_gateways = [g for g in deepcopy(gateway_extractions) if g.check_type_for_evaluation(label)]

        same_gateway_relations = pet_reader.get_doc_relations(doc_name)[SAME_GATEWAY]
        same_gateway_relations_filtered = [sgr for sgr in same_gateway_relations if sgr[SOURCE_ENTITY_TYPE] == label]

        if label == XOR_GATEWAY:
            gold_gateways = flatten_list(pet_reader.token_dataset.GetXORGateways(doc_name))
        elif label == AND_GATEWAY:
            gold_gateways = flatten_list(pet_reader.token_dataset.GetANDGateways(doc_name))
        elif label == NO_GATEWAY_RELATIONS:
            gold_gateways = []
        else:
            raise ValueError(f"'{label}' is not a valid Gateway type")

        # gateways connected by a same-gateway relation were extracted as one gateway -> reduce the gold count
        number_gold_gateways = len(gold_gateways) - len(same_gateway_relations_filtered)
        number_pred_gateways = len(pred_gateways)

        # every prediction can be matched to at most one gold gateway of the same type
        tp = min(number_pred_gateways, number_gold_gateways)
        # fp = number of elements in predictions that remain unmatched
        fp = number_pred_gateways - tp
        # fn = number of elements in gold standard that remain unmatched
        fn = number_gold_gateways - tp

        return self.compute_metrics_dict(tp, fp, fn, number_gold_gateways)
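    # Under this counting scheme, precision and recall reduce to min(pred, gold) / pred and
    # min(pred, gold) / gold respectively (assuming compute_metrics_dict, inherited from
    # AbstractClassificationBenchmark, applies the standard precision/recall/F1 definitions).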


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    logger.setLevel(logging.DEBUG)

    # A) evaluate with gold standard relations
    # geb_full = SimpleGatewayTypeAndNumberBenchmark(approach_name="ge=standard_rc=goldstandard_vote=full_TESTDOCS",
    #                                                gateway_extractor=GatewayExtractor(GoldstandardRelationClassifier(),
    #                                                                                   full_branch_vote=True))
    # geb_full.evaluate_documents(TEST_DOCS)
    #
    # geb_limited = SimpleGatewayTypeAndNumberBenchmark(approach_name="ge=standard_rc=goldstandard_vote=limited_TESTDOCS",
    #                                                   gateway_extractor=GatewayExtractor(
    #                                                       GoldstandardRelationClassifier(),
    #                                                       full_branch_vote=False))
    # geb_limited.evaluate_documents(TEST_DOCS)

    # B) evaluate with BRCNN-predicted relations
    brcnn_ensemble = NeuralRelationClassifierEnsemble(
        ensemble_path=config[MODELS][ACTIVITY_RELATION_CLASSIFIER],
        args=get_static_args(),
        train_size=100,  # only a dummy value, not used
        # activate in case of memory issues
        # seeds=[10]
    )

    for name, full_branch_vote in [("full", True), ("limited", False)]:
        geb = SimpleGatewayTypeAndNumberBenchmark(approach_name=f"ge=standard_rc=brcnn128_vote={name}",
                                                  gateway_extractor=GatewayExtractor(brcnn_ensemble,
                                                                                     full_branch_vote=full_branch_vote))
        geb.evaluate_documents(TEST_DOCS)
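# Hypothetical invocation from the repository root (the exact path is an assumption, not verified):
#   python relation_approaches/GatewayExtractionBenchmark.py
# Results are written to data/results_relation_approaches/gateway_extraction/<approach_name>/ as
# predictions.pkl / .json / .txt plus the metric files produced by write_metrics.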