-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathalphabet.py
executable file
·116 lines (98 loc) · 3.94 KB
/
alphabet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- coding: utf-8 -*-
# follow NCRF++, a Neural Sequence Labeling Toolkit.
"""
Alphabet maps objects to integer ids and provides a two-way mapping between the objects and their indices.
"""
import json
import os
import sys
class Alphabet:
    """Two-way mapping between hashable instances and integer indices.

    Index 0 is reserved (``default_index``); instances are numbered from 1 in
    insertion order, so ``instances[i - 1]`` is the instance with index ``i``.
    Unless ``label`` is true, the unknown token ``"</unk>"`` is registered at
    construction time and therefore receives index 1.
    """

    def __init__(self, name, label=False, keep_growing=True):
        """Create an alphabet.

        :param name: Alphabet name; also the default file stem for save/load.
        :param label: If true, the unknown token is NOT registered.
        :param keep_growing: If true, ``get_index`` registers unseen instances.
        """
        self.name = name
        self.UNKNOWN = "</unk>"
        self.label = label
        self.instance2index = {}
        self.instances = []
        self.keep_growing = keep_growing
        # Index 0 is occupied by default, all else following.
        self.default_index = 0
        self.next_index = 1
        if not self.label:
            self.add(self.UNKNOWN)

    def clear(self, keep_growing=True):
        """Drop every registered instance and reset the index counter.

        Note that the unknown token is not re-registered here.

        :param keep_growing: New value for the ``keep_growing`` flag.
        """
        self.instance2index = {}
        self.instances = []
        self.keep_growing = keep_growing
        # Index 0 is occupied by default, all else following.
        self.default_index = 0
        self.next_index = 1

    def add(self, instance):
        """Register ``instance`` under the next free index if it is new."""
        if instance not in self.instance2index:
            self.instances.append(instance)
            self.instance2index[instance] = self.next_index
            self.next_index += 1

    def get_index(self, instance):
        """Return the index of ``instance``.

        An unseen instance is registered and given a fresh index while the
        alphabet is growing; otherwise the index of the unknown token is
        returned.

        :raises KeyError: if the alphabet is closed, the instance is unseen
            and the unknown token is not registered.
        """
        try:
            return self.instance2index[instance]
        except KeyError:
            if self.keep_growing:
                index = self.next_index
                self.add(instance)
                return index
            return self.instance2index[self.UNKNOWN]

    def get_instance(self, index):
        """Return the instance stored under ``index``.

        Index 0 yields the first instance for label alphabets and ``None``
        otherwise. An index past the end prints a warning and yields the
        first instance.
        """
        if index == 0:
            if self.label:
                return self.instances[0]
            # First index is occupied by the wildcard element.
            return None
        try:
            return self.instances[index - 1]
        except IndexError:
            print('WARNING:Alphabet get_instance ,unknown instance, return the first label.')
            return self.instances[0]

    def size(self):
        """Return the number of index slots, including the reserved index 0."""
        # The +1 accounts for index 0, which is reserved and never assigned to
        # an instance (real indices start at 1).
        return len(self.instances) + 1

    def iteritems(self):
        """Return the (instance, index) pairs of the alphabet."""
        if sys.version_info[0] < 3:  # Python 2 dicts expose iteritems() instead.
            return self.instance2index.iteritems()
        return self.instance2index.items()

    def enumerate_items(self, start=1):
        """Return (index, instance) pairs beginning at index ``start``.

        :raises IndexError: if ``start`` is outside ``[1, size())``.
        """
        if start < 1 or start >= self.size():
            raise IndexError("Enumerate is allowed between [1 : size of the alphabet)")
        return zip(range(start, len(self.instances) + 1), self.instances[start - 1:])

    def close(self):
        """Stop registering unseen instances in ``get_index``."""
        self.keep_growing = False

    def open(self):
        """Resume registering unseen instances in ``get_index``."""
        self.keep_growing = True

    def get_content(self):
        """Return the serialisable state: both mapping directions."""
        return {'instance2index': self.instance2index, 'instances': self.instances}

    def from_json(self, data):
        """Restore both mapping directions from ``get_content``-style data."""
        self.instances = data["instances"]
        self.instance2index = data["instance2index"]
        # Keep the counter in step with the restored data; otherwise the next
        # add() would reuse an index that is already taken.
        self.next_index = len(self.instances) + 1

    def save(self, output_directory, name=None):
        """
        Save both alphabet records to the given directory.

        Saving is best effort: a failure is printed, not raised.

        :param output_directory: Directory to write the JSON file into.
        :param name: The alphabet saving name, optional (defaults to self.name).
        :return:
        """
        saving_name = name if name else self.name
        try:
            path = os.path.join(output_directory, saving_name + ".json")
            with open(path, 'w') as handle:
                json.dump(self.get_content(), handle)
        except Exception as e:
            print("Exception: Alphabet is not saved: %s" % repr(e))

    def load(self, input_directory, name=None):
        """
        Load the alphabet records previously written by ``save``.

        :param input_directory: Directory containing the JSON file.
        :param name: The alphabet loading name, optional (defaults to self.name).
        :return:
        """
        loading_name = name if name else self.name
        path = os.path.join(input_directory, loading_name + ".json")
        with open(path) as handle:
            self.from_json(json.load(handle))