inference.py
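"""Run CCAPNet inference on a single address line.

Usage sketch (the checkpoint path and input below are illustrative):

    python inference.py --model-path checkpoints/ccapnet.pt \
        --input '123 main st springfield'
"""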
import argparse
import re
import sys
from pathlib import Path

import torch

from model import CCAPNet
from train import chars_to_tensor, address_chars, label_chars
from address.utils import normalize_text, build_vocabulary
from address.proc_gen import TokenCategory


def parse_arguments():
    parser = argparse.ArgumentParser(description='CCAPNet inference script.')
    parser.add_argument('--model-path', type=str, metavar='PATH', required=True,
                        help='model checkpoint to load')
    parser.add_argument('--input', type=str, metavar='STR',
                        help='address line to parse')
    parser.add_argument('--file', type=str, metavar='PATH',
                        help='file containing address lines to parse')
    return parser.parse_args()


def process_arguments(args):
    # Validate the argument combination before the (comparatively expensive)
    # checkpoint load.
    if args.input and args.file:
        print("Cannot use both '--input' and '--file'.")
        sys.exit(1)
    model_path = Path(args.model_path)
    try:
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
        checkpoint = torch.load(model_path, map_location='cpu')
    except Exception as e:
        print(f'Failed to load model ({e}), exiting.')
        sys.exit(1)
    return checkpoint['model_state_dict']


def check_continuity(tc_list):
    """Return True if every category appears in one contiguous run.

    Separators are ignored; a category that reappears after a different
    category has been seen means its block was interrupted.
    """
    seen = []
    for cat in tc_list:
        if cat == TokenCategory.SEPARATOR:
            continue
        elif cat not in seen:
            seen.append(cat)
        elif cat in seen[:-1]:
            # cat appeared earlier but is not the most recent category,
            # so its run is fragmented
            return False
    return True
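

# Illustrative behaviour of check_continuity, assuming TokenCategory has
# members such as STREET and CITY (the actual member names are defined in
# address.proc_gen; only SEPARATOR is used above):
#   check_continuity([STREET, STREET, SEPARATOR, CITY])  -> True
#   check_continuity([STREET, CITY, STREET])             -> False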


if __name__ == '__main__':
    args = parse_arguments()
    state_dict = process_arguments(args)

    # Rebuild the vocabularies used during training so the indices line up.
    x_vocab = build_vocabulary(address_chars)
    y_vocab = build_vocabulary(label_chars)
    y_vocab_inv = {y_vocab[c]: c for c in y_vocab}
    tc_inv = {str(e.value): e for e in TokenCategory}

    model = CCAPNet(len(x_vocab), len(y_vocab))
    model.load_state_dict(state_dict)
    model.eval()
    if args.input:
        text = normalize_text(args.input)
        assert len(text) > 0, 'Normalized text has zero length'
        z = chars_to_tensor(text, x_vocab)
        z = z.unsqueeze(0)  # add a batch dimension
        length = torch.tensor([z.numel()])
        with torch.no_grad():
            logits = model(z, length)
        pred = logits.argmax(2).squeeze(0)

        # Map predicted label indices back to token categories.
        tc_list = [tc_inv[y_vocab_inv[idx]] for idx in pred.tolist()]
        if not check_continuity(tc_list):
            print('WARNING: parser did not categorize text in contiguous blocks')

        # Group input characters by predicted category; separators are appended
        # to every group and the resulting runs of whitespace collapsed below.
        tokens = {}
        for i, cat in enumerate(tc_list):
            if cat == TokenCategory.SEPARATOR:
                for key in tokens:
                    tokens[key].append(' ')
                continue
            if cat.name not in tokens:
                tokens[cat.name] = []
            tokens[cat.name].append(text[i])
        for key in tokens:
            s = ''.join(tokens[key])
            s = re.sub(r'\s+', ' ', s)
            tokens[key] = s.strip()

        print(f"Normalized text: '{text}'")
        print(f"Tokens: {tokens}")
    elif args.file:
        raise NotImplementedError('Cannot handle files yet.')
    else:
        print("Neither '--input' nor '--file' was provided.")
        sys.exit(1)