-
Notifications
You must be signed in to change notification settings - Fork 7
/
util.py
123 lines (102 loc) · 3.85 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import json
import logging
import constant as C
import conlleval
def get_logger(name, level=C.LOGGING_LEVEL, log_file=None):
    """Get a logger by name, attaching console (and optional file) output.

    ``logging.getLogger`` returns the same object for the same name, so the
    handlers are only attached when an equivalent one is not already present.
    Calling this repeatedly for one name therefore does not duplicate every
    log record.

    :param name: Logger name (usu. __name__).
    :param level: Logging level (default=C.LOGGING_LEVEL).
    :param log_file: Optional path; when given, records are also written to
        this file using UTF-8 encoding.
    :return: The configured ``logging.Logger``.
    """
    import os

    logger = logging.getLogger(name)
    # FileHandler is a subclass of StreamHandler, so compare the exact type
    # to find a plain console handler.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        logger.addHandler(logging.StreamHandler())
    if log_file:
        # FileHandler stores the absolute path of its file in ``baseFilename``.
        path = os.path.abspath(log_file)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == path
            for h in logger.handlers
        )
        if not already_attached:
            logger.addHandler(logging.FileHandler(log_file, encoding='utf-8'))
    logger.setLevel(level)
    return logger
def evaluate(results, idx_token, idx_label, writer=None):
    """Score batched sequence-labelling predictions with conlleval.

    Each (token, gold label, predicted label) triple is rendered as one
    space-separated line, and every sequence is followed by an empty line.
    The lines are passed to ``conlleval.evaluate``. ``conlleval.report`` is
    then called once with its default output and, when ``writer`` is given,
    once more with ``out=writer`` (followed by ``writer.flush()``).

    :param results: Iterable of per-batch tuples
        (predictions, gold labels, sequence lengths, tokens).
    :param idx_token: Mapping from token index to token.
    :param idx_label: Mapping from label index to label.
    :param writer: Optional object with write() and flush() that receives a
        copy of the report.
    :return: Tuple (F-score, precision, recall) of the overall metrics.
    """
    conll_lines = []
    for batch in results:
        batch_preds, batch_golds, batch_lens, batch_tokens = batch
        for seq in zip(batch_preds, batch_golds, batch_lens, batch_tokens):
            seq_preds, seq_golds, seq_len, seq_tokens = seq
            # Only the first ``length`` positions of each sequence are scored.
            length = int(seq_len.item())
            pred_ids = seq_preds.data.tolist()[:length]
            gold_ids = seq_golds.data.tolist()[:length]
            token_ids = seq_tokens.data.tolist()[:length]
            # NOTE(review): the lookup fallbacks are C.UNK_INDEX (named like an
            # index, not a token) and the integer 0 (not a label string such
            # as 'O') -- confirm this is intended.
            conll_lines.extend(
                '{} {} {}'.format(
                    idx_token.get(tid, C.UNK_INDEX),
                    idx_label.get(gid, 0),
                    idx_label.get(pid, 0),
                )
                for pid, gid, tid in zip(pred_ids, gold_ids, token_ids)
            )
            conll_lines.append('')  # empty line terminates the sequence
    counts = conlleval.evaluate(conll_lines)
    overall, _ = conlleval.metrics(counts)
    conlleval.report(counts)
    if writer:
        conlleval.report(counts, out=writer)
        writer.flush()
    return overall.fscore, overall.prec, overall.rec
class Config(dict):
    """A dict whose items are also accessible as attributes.

    Every item stored through ``__setitem__`` (including ``cfg.key = value``)
    is mirrored into the instance ``__dict__``, so ``cfg['lr']`` and
    ``cfg.lr`` return the same object. Nested dicts, and dicts that are direct
    elements of a list, are converted to ``Config`` when values arrive via the
    constructor, ``update`` or ``set_dict``.

    NOTE: because keys are mirrored into the instance ``__dict__``, a key named
    like a dict method (``items``, ``keys``, ``update``, ...) hides that method
    for attribute access on that instance.
    """

    def __init__(self, *args, **kwargs):
        # dict.__init__ keeps the builtin argument handling (one mapping or
        # iterable of pairs, plus keyword items), but it inserts at C level and
        # bypasses our __setitem__. Re-store every item so that all keys,
        # including those from pair iterables and kwargs, are mirrored into
        # __dict__ and nested dicts are wrapped.
        super(Config, self).__init__(*args, **kwargs)
        for k, v in list(dict.items(self)):
            self[k] = Config._wrap(v)

    @staticmethod
    def _wrap(value):
        """Return value with a dict, or dicts directly inside a list, as Config."""
        if isinstance(value, dict):
            return Config(value)
        if isinstance(value, list):
            return [Config(x) if isinstance(x, dict) else x for x in value]
        return value

    def __setattr__(self, key, value):
        # Attribute writes become dict items (and are mirrored back below).
        self[key] = value

    def __setitem__(self, key, value):
        super(Config, self).__setitem__(key, value)
        self.__dict__[key] = value

    def __delattr__(self, item):
        del self[item]

    def __delitem__(self, key):
        super(Config, self).__delitem__(key)
        # Items inserted by dict methods that bypass __setitem__ (e.g.
        # setdefault) have no mirror, so tolerate a missing attribute.
        self.__dict__.pop(key, None)

    def set_dict(self, dict_obj):
        """Store every item of dict_obj, wrapping nested dicts and lists of dicts."""
        for k, v in dict_obj.items():
            self[k] = Config._wrap(v)

    def update(self, dict_obj=(), **kwargs):
        """Store items from dict_obj (mapping or pairs) and kwargs, wrapping nested dicts.

        Like ``dict.update`` but every value passes through the same Config
        wrapping as the constructor; keyword items override dict_obj items.
        """
        for k, v in dict(dict_obj, **kwargs).items():
            self[k] = Config._wrap(v)

    def clone(self):
        """Return a new Config; nested dicts and lists are re-created, other values shared."""
        return Config(dict(self))

    @staticmethod
    def read(path):
        """Read configuration from JSON format file.

        :param path: Path to the configuration file (read as UTF-8).
        :return: Config object.
        """
        # Use a context manager so the file handle is closed promptly.
        with open(path, 'r', encoding='utf-8') as f:
            json_obj = json.load(f)
        return Config(json_obj)

    def update_value(self, keys, value):
        """Set a nested value addressed by a dot-separated path.

        Each intermediate component is first tried as an integer index (for
        lists, or mappings keyed by ints) and falls back to the raw string
        key. The final component is converted to an int when the container
        it indexes is a list.

        :param keys: Dot-separated path, e.g. ``'model.layers.0.dim'``.
        :param value: Value to store at that path.
        """
        parts = keys.split('.')  # str.split always yields at least one part
        tgt = self
        for k in parts[:-1]:
            try:
                tgt = tgt[int(k)]
            except (ValueError, KeyError, IndexError, TypeError):
                tgt = tgt[k]
        last = parts[-1]
        if isinstance(tgt, list):
            try:
                last = int(last)
            except ValueError:
                pass  # let the list raise its usual TypeError below
        tgt[last] = value