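"""Passage expansion with TILDE (ielab/TILDE).

For every passage in a TSV collection (one `pid<TAB>passage` per line), this
script scores the full BERT vocabulary with the TILDE language-model head,
keeps the top-k highest-likelihood tokens, filters out tokens already present
in the passage as well as stopword/noise token ids, and appends the surviving
expansion tokens to the passage. Expanded passages are written to a JSONL file.
"""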
import json
import os
import re
from argparse import ArgumentParser

import numpy as np
import torch
from nltk.corpus import stopwords
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertLMHeadModel, BertTokenizer, DataCollatorWithPadding

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def clean_vocab(tokenizer, do_stopwords=True):
    """Split the tokenizer vocabulary into good (usable) and bad (excluded) expansion token ids."""
    stop_words = set()
    if do_stopwords:
        stop_words = set(stopwords.words('english'))
        # keep some common words in ms marco questions
        # stop_words.difference_update(["where", "how", "what", "when", "which", "why", "who"])
        stop_words.add("definition")

    vocab = tokenizer.get_vocab()
    tokens = vocab.keys()

    good_ids = []
    bad_ids = []
    # exclude stopwords that map to a single token id
    for stop_word in stop_words:
        ids = tokenizer(stop_word, add_special_tokens=False)["input_ids"]
        if len(ids) == 1:
            bad_ids.append(ids[0])

    for token in tokens:
        token_id = vocab[token]
        if token_id in bad_ids:
            continue
        if token[0] == '#' and len(token) > 1:
            # keep WordPiece continuation tokens such as "##ing"
            good_ids.append(token_id)
        else:
            if not re.match("^[A-Za-z0-9_-]*$", token):
                # exclude tokens containing punctuation or non-ASCII characters
                bad_ids.append(token_id)
            else:
                good_ids.append(token_id)
    bad_ids.append(2015)  # 2015 is the BERT id of "##s"; add it to the stopwords
    return good_ids, bad_ids
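
# A minimal sanity check for clean_vocab (illustrative only; requires the NLTK
# stopwords corpus, e.g. via nltk.download('stopwords')):
#
#   tok = BertTokenizer.from_pretrained('bert-base-uncased')
#   good_ids, bad_ids = clean_vocab(tok)
#   # single-token stopwords such as "the" end up in bad_ids
#   assert tok.convert_tokens_to_ids('the') in bad_ids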


class MarcoEncodeDataset(Dataset):
    """Reads an MS MARCO-style TSV collection (`pid<TAB>passage` per line) and tokenizes passages."""

    def __init__(self, path, tokenizer, p_max_len=128):
        self.tok = tokenizer
        self.p_max_len = p_max_len
        self.passages = []
        self.pids = []
        with open(path, 'rt') as fin:
            lines = fin.readlines()
            for line in tqdm(lines, desc="Loading collection"):
                pid, passage = line.split("\t")
                self.passages.append(passage)
                self.pids.append(pid)

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, item):
        psg = self.passages[item]
        encoded_psg = self.tok.encode_plus(
            psg,
            max_length=self.p_max_len,
            truncation='only_first',
            return_attention_mask=False,
        )
        encoded_psg.input_ids[0] = 1  # TILDE uses token id 1 as the indicator of passage input.
        return encoded_psg

    def get_pids(self):
        return self.pids
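
# A minimal usage sketch for the dataset (the collection path is illustrative):
#
#   tok = BertTokenizer.from_pretrained('bert-base-uncased')
#   ds = MarcoEncodeDataset('./collection.tsv', tok)
#   print(len(ds), ds.get_pids()[:3], ds[0].input_ids[:5])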


def main(args):
    model = BertLMHeadModel.from_pretrained("ielab/TILDE", cache_dir='./cache')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', use_fast=True, cache_dir='./cache')
    model.eval().to(DEVICE)

    with open(os.path.join(args.output_dir, f"collection-tilde-expanded-top{args.topk}.jsonl"), 'w+') as wf:
        _, bad_ids = clean_vocab(tokenizer)
        encode_dataset = MarcoEncodeDataset(args.corpus_path, tokenizer)
        encode_loader = DataLoader(
            encode_dataset,
            batch_size=args.batch_size,
            collate_fn=DataCollatorWithPadding(
                tokenizer,
                max_length=128,
                padding='max_length'
            ),
            shuffle=False,  # important: keep batch order aligned with pids
            drop_last=False,  # important: expand every passage
            num_workers=args.num_workers,
        )
        pids = encode_dataset.get_pids()
        COUNTER = 0
        for batch in tqdm(encode_loader):
            passage_input_ids = batch.input_ids.numpy()
            batch.to(DEVICE)

            with torch.no_grad():
                # vocabulary logits at position 0 (the passage-indicator token)
                # score every token as an expansion candidate
                logits = model(**batch, return_dict=True).logits[:, 0]
                batch_selected = torch.topk(logits, args.topk).indices.cpu().numpy()

            expansions = []
            for i, selected in enumerate(batch_selected):
                # drop candidates already in the passage, then drop stopword/noise ids
                expand_term_ids = np.setdiff1d(np.setdiff1d(selected, passage_input_ids[i], assume_unique=True),
                                               bad_ids, assume_unique=True)
                expansions.append(expand_term_ids)

            for ind, passage_input_id in enumerate(passage_input_ids):
                passage_input_id = passage_input_id[passage_input_id != 0][1:]  # strip [PAD] ids and the leading indicator token
                expanded_passage = np.append(passage_input_id, expansions[ind]).tolist()

                if args.store_raw:
                    expanded_passage = tokenizer.decode(expanded_passage)

                temp = {
                    "pid": pids[COUNTER],
                    "psg": expanded_passage
                }
                COUNTER += 1
                wf.write(f'{json.dumps(temp)}\n')
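
# Each output line is one JSON object; with --store_raw "psg" is decoded text,
# otherwise it is a list of token ids (values illustrative):
#
#   {"pid": "0", "psg": "original passage text plus appended expansion terms ..."}
#   {"pid": "1", "psg": [2023, 2003, 1996, ...]}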


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--corpus_path', required=True)
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--topk', default=200, type=int,
                        help='Number of highest-likelihood tokens to expand the original passage with. '
                             'NOTE: this is the number before filtering out expansion tokens that are already in the original passage.')
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--store_raw', action='store_true',
                        help='Store the expanded passages as raw text; omit this flag to store token ids instead.')
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    main(args)
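
# Example invocation (paths are illustrative):
#
#   python expansion.py --corpus_path ./collection.tsv --output_dir ./expanded --topk 200 --store_raw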