# keyword_decode.py (forked from minimaxir/gpt-2-keyword-generation)
import re

# Single-character delimiters used to mark each field of an encoded document.
DELIMS = {
    'section': '~',
    'category': '`',
    'keywords': '^',
    'title': '@',
    'body': '}'
}
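
# For reference, an encoded document marks each field with
# DELIMS['section'] + DELIMS[field]; a hypothetical example:
#
#   <|startoftext|>~`Politics~^election, vote~@Polls open today~}Voters lined up early.<|endoftext|>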


def build_pattern(sections, start_token, end_token):
    """Compile a regex that captures each requested section of an encoded text."""
    # sections may not be in the correct order: normalize to the canonical order
    key_order = ['category', 'keywords', 'title', 'body']
    sections = [section for section in key_order if section in sections]

    pattern_text = re.escape(start_token) + '(?:.*)'
    for section in sections:
        pattern_text += '(?:{})'.format(
            re.escape(DELIMS['section'] + DELIMS[section])) + '(.*)'
    pattern_text += '(?:.*)' + re.escape(end_token)
    return re.compile(pattern_text, flags=re.MULTILINE)
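
# A quick sketch of what this builds for the default case (exact escaping
# depends on re.escape and the Python version):
#
#   build_pattern(['title'], '<|startoftext|>', '<|endoftext|>')
#   # matches text shaped like: <|startoftext|>...~@<title>...<|endoftext|>
#   # with the title captured as group 1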


def decode_texts(texts, sections=['title'],
                 start_token="<|startoftext|>",
                 end_token="<|endoftext|>"):
    """Extract the requested sections from each text, skipping non-matching texts."""
    # get the index of the group(s) we want to extract; note that groups are
    # captured in the canonical key_order used by build_pattern, not
    # necessarily in the order the sections were passed in
    group_indices = list(range(1, len(sections) + 1))
    assert len(group_indices) > 0

    pattern = build_pattern(sections, start_token, end_token)

    if not isinstance(texts, list):
        texts = [texts]

    decoded_texts = []
    for text in texts:
        match = pattern.match(text)
        if match is None:
            continue
        decoded_text_attrs = tuple(match.group(i) for i in group_indices)

        # unwrap single-section results so callers get plain strings
        if len(group_indices) == 1:
            decoded_text_attrs = decoded_text_attrs[0]
        decoded_texts.append(decoded_text_attrs)
    return decoded_texts
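
# A minimal usage sketch (the sample string is hypothetical):
#
#   texts = ['<|startoftext|>~@Polls open today~}Voters lined up early.<|endoftext|>']
#   decode_texts(texts, sections=['title', 'body'])
#   # -> [('Polls open today', 'Voters lined up early.')]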


def decode_file(file_path, out_file='texts_decoded.txt',
                doc_delim='=' * 20 + '\n',
                sections=['title'],
                start_token="<|startoftext|>",
                end_token="<|endoftext|>"):
    """Decode every generated document in a file and write the results to out_file."""
    assert len(sections) == 1, \
        "This function only supports output of a single section for now."

    doc_pattern = re.compile(re.escape(start_token) + '(.*)' +
                             re.escape(end_token), flags=re.MULTILINE)

    with open(file_path, 'r', encoding='utf8', errors='ignore') as f:
        # warning: loads the entire file into memory!
        docs = doc_pattern.findall(f.read())
    docs = [start_token + doc + end_token for doc in docs]

    decoded_docs = decode_texts(docs,
                                sections=sections,
                                start_token=start_token,
                                end_token=end_token)

    with open(out_file, 'w', encoding='utf8', errors='ignore') as f:
        for doc in decoded_docs:
            f.write("{}\n{}".format(doc, doc_delim))
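

# A minimal usage sketch (file names are hypothetical):
#
#   decode_file('gpt2_gentext.txt', out_file='titles_decoded.txt',
#               sections=['title'])
#
# This pulls the title out of every <|startoftext|>...<|endoftext|> document
# in gpt2_gentext.txt and writes each one followed by a doc_delim separator.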