# squad2doc.py
#
# Converts SQuAD-style JSON files into .docx documents in which answer spans
# are colour-coded, and writes a meta.jsonl index describing each document.
import json
import sys
import docx
import tqdm
def make_spans(color_map, characters):
    """Collapse per-character labels into runs of identically labelled text.

    Args:
        color_map: Sequence with one entry per character; each entry is the
            list of answer labels covering that character.
        characters: The text itself (a string or any iterable of
            single-character strings).

    Returns:
        A list of ``(labels, text)`` tuples in input order, where ``text`` is
        a maximal run of consecutive characters whose label entries compare
        equal. Empty input yields an empty list. If the two inputs differ in
        length, the surplus of the longer one is ignored (``zip`` semantics).
    """
    spans = []
    run_labels = None
    run_chars = []
    for labels, ch in zip(color_map, characters):
        # A run is "open" whenever run_chars is non-empty; testing the buffer
        # rather than a None sentinel keeps the first iteration and empty
        # input well-defined.
        if run_chars and labels != run_labels:
            spans.append((run_labels, "".join(run_chars)))
            run_chars = []
        run_labels = labels
        run_chars.append(ch)
    if run_chars:
        # Flush the final run; skipped entirely when there was no input.
        spans.append((run_labels, "".join(run_chars)))
    return spans
def index_fonts(palette_path="palette.txt"):
    """Load the highlight colours from a palette file.

    Every line that starts with ``#`` (after stripping whitespace) is taken
    as a colour in ``#RRGGBB`` form; all other lines are ignored.

    Args:
        palette_path: Path of the palette file. Defaults to ``palette.txt``
            in the current working directory.

    Returns:
        A list of ``docx.shared.RGBColor`` objects in file order.

    Raises:
        ValueError: If the same colour is listed more than once. Colours are
            compared after upper-casing, so ``#ff0000`` and ``#FF0000`` count
            as duplicates.
    """
    with open(palette_path) as palette_file:
        stripped_lines = (line.strip() for line in palette_file)
        hex_codes = [
            line[1:].upper() for line in stripped_lines if line.startswith("#")
        ]

    # Check uniqueness on the normalised codes: comparing the raw lines would
    # let two spellings of the same colour through.
    seen = set()
    for code in hex_codes:
        if code in seen:
            raise ValueError(f"duplicate colour #{code} in {palette_path}")
        seen.add(code)

    return [docx.shared.RGBColor.from_string(code) for code in hex_codes]
# The palette is loaded once, when the module is imported; para2txt() indexes
# into this list to colour answer spans.
rgb_colors = index_fonts()
def para2txt(p, p_idx, doc):
    """Append one paragraph record and its questions to ``doc``.

    The context is written as a single docx paragraph. Every stretch of text
    covered by the same combination of answers becomes one run, and each
    distinct non-empty combination is given its own colour from
    ``rgb_colors``. The questions follow, each under a bold "Question N"
    heading.

    Args:
        p: Dict with a ``"context"`` string and a ``"qas"`` list. Each qa has
            ``"id"``, ``"question"``, ``"answers"`` and optionally
            ``"plausible_answers"``; every answer has ``"text"`` and
            ``"answer_start"`` (character offset into the context).
        p_idx: Number shown in the "Text number" heading.
        doc: Document to append to (must provide ``add_paragraph``).

    Returns:
        A ``(total_len, color_lookup, question_list)`` tuple:
        ``total_len`` is a running character count of the text added
        (headings, context and questions plus small per-item allowances);
        ``color_lookup`` maps each ``"+"``-joined combination of
        ``"<question id>_<answer index>"`` labels to its index in
        ``rgb_colors``, with ``""`` (no answer) mapped to ``-1``;
        ``question_list`` holds the question ids in order.

    Raises:
        ValueError: If an answer's text is not found at its ``answer_start``.
        IndexError: If the paragraph needs more colours than the palette has.
    """
    heading = f"Text number {p_idx}"
    doc.add_paragraph().add_run(heading).bold = True
    total_len = len(heading) + 1

    context_par = doc.add_paragraph('')
    ctx = p["context"]

    # One label list per context character, naming every answer covering it.
    color_map = [[] for _ in ctx]
    for qa in p["qas"]:
        q_id = qa["id"]
        candidates = qa["answers"] + qa.get("plausible_answers", [])
        for a_idx, a in enumerate(candidates):
            atext = a["text"]
            start = int(a["answer_start"])
            end = start + len(atext)
            # Raise instead of assert so misaligned data is still caught
            # when Python runs with -O.
            if ctx[start:end] != atext:
                raise ValueError(
                    f"answer {a_idx} of question {q_id!r} does not match "
                    f"the context at offset {start}"
                )
            label = f"{q_id}_{a_idx}"
            for labels in color_map[start:end]:
                labels.append(label)

    # Each unique overlap of answers gets its own index. The table starts
    # with one entry, so the first real combination receives index 1 and
    # palette entry 0 is never used; the numbering is kept because the table
    # is returned to the caller.
    color_lookup = {"": -1}
    for labels in color_map:
        color_lookup.setdefault("+".join(labels), len(color_lookup))

    # An empty context has no runs to add.
    spans = make_spans(color_map, ctx) if ctx else []
    for labels, text in spans:
        run = context_par.add_run(text)
        font_idx = color_lookup["+".join(labels)]
        if font_idx < 0:
            continue  # text outside any answer keeps the default colour
        if font_idx >= len(rgb_colors):
            raise IndexError(
                f"paragraph {p_idx} needs colour index {font_idx} but the "
                f"palette only has {len(rgb_colors)} colours"
            )
        run.font.color.rgb = rgb_colors[font_idx]
    total_len += len(ctx) + 1

    question_list = []
    for q_idx, qa in enumerate(p["qas"]):
        question = qa["question"]
        question_list.append(qa["id"])
        q_heading = f"Question {q_idx}"
        doc.add_paragraph("").add_run(q_heading).bold = True
        doc.add_paragraph(question)
        total_len += len(q_heading) + len(question) + 2
    return total_len, color_lookup, question_list
def main(input_paths, out_dir="squad2-en", max_chars=900000):
    """Convert JSON input files into size-limited .docx files plus metadata.

    Each entry of an input file's ``"data"`` list is written under a
    "Document number N" heading followed by its paragraphs (via
    ``para2txt``) and a page break. One JSON line per entry is appended to
    ``meta.jsonl``. Whenever the running character count exceeds
    ``max_chars`` the current .docx is saved and a new one is started.

    Args:
        input_paths: Paths of the JSON files to convert.
        out_dir: Directory receiving ``meta.jsonl`` and the
            ``squad2_NNN.docx`` files; created if missing.
        max_chars: Character count above which the current .docx is
            considered full.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)

    def docx_path(counter):
        return os.path.join(out_dir, f"squad2_{counter:03d}.docx")

    doc = docx.Document()
    total_len = 0
    file_counter = 0
    d_idx = 0
    unsaved = False  # True while `doc` holds content not yet written to disk

    # Explicit UTF-8: the metadata is dumped with ensure_ascii=False, so the
    # locale's default encoding cannot be relied on to represent it.
    meta_path = os.path.join(out_dir, "meta.jsonl")
    with open(meta_path, "wt", encoding="utf-8") as meta:
        for fname in input_paths:
            with open(fname, "rt", encoding="utf-8") as f:
                data = json.load(f)["data"]
            for d in tqdm.tqdm(data):
                docmeta = {
                    "file": fname,
                    "title": d["title"],
                    "sequence_idx": d_idx,
                    "paragraphs": [],
                }
                heading = doc.add_paragraph("").add_run(
                    f"Document number {d_idx}")
                heading.bold = True
                heading.underline = True
                unsaved = True
                for p_idx, p in enumerate(d["paragraphs"]):
                    para_len, color_lookup, question_list = para2txt(
                        p, p_idx, doc)
                    docmeta["paragraphs"].append(
                        (p_idx, color_lookup, question_list))
                    total_len += para_len
                # Flushed per entry so the index on disk keeps pace with the
                # conversion.
                print(json.dumps(docmeta, sort_keys=True,
                                 ensure_ascii=False), file=meta, flush=True)
                doc.add_page_break()
                d_idx += 1
                if total_len > max_chars:  # document full!
                    doc.save(docx_path(file_counter))
                    file_counter += 1
                    total_len = 0
                    doc = docx.Document()
                    unsaved = False

    # Write whatever is left over; skipped when the last entry already
    # triggered a save, so no empty trailing file is produced.
    if unsaved:
        doc.save(docx_path(file_counter))


if __name__ == "__main__":
    main(sys.argv[1:])