build_hn.py
from argparse import ArgumentParser
from transformers import AutoTokenizer
import os
import random
from tqdm import tqdm
from datetime import datetime
from multiprocessing import Pool
from tevatron.preprocessor import MarcoPassageTrainPreProcessor as TrainPreProcessor


def load_ranking(rank_file, relevance, n_sample, depth, collection_size):
    """Group a ranking file (qid, pid, score per line) by query and yield
    (query id, relevant passage ids, hard-negative passage ids) triples."""
    with open(rank_file) as rf:
        lines = iter(rf)
        q_0, p_0, _ = next(lines).strip().split()

        curr_q = q_0
        negatives = [] if p_0 in relevance[q_0] else [p_0]

        while True:
            try:
                q, p, _ = next(lines).strip().split()
                if q != curr_q:
                    # Finished one query: keep the top-`depth` candidates, shuffle them,
                    # and pad with random collection indices if fewer than `n_sample` remain.
                    negatives = negatives[:depth]
                    random.shuffle(negatives)
                    if len(negatives) < n_sample:
                        rand_negatives = random.sample(range(collection_size), n_sample - len(negatives))
                        negatives.extend(rand_negatives)
                    yield curr_q, relevance[curr_q], negatives[:n_sample]
                    curr_q = q
                    negatives = [] if p in relevance[q] else [p]
                else:
                    # Skip judged-relevant passages; everything else counts as a hard negative.
                    if p not in relevance[q]:
                        negatives.append(p)
            except StopIteration:
                # Flush the final query the same way, then stop.
                negatives = negatives[:depth]
                random.shuffle(negatives)
                if len(negatives) < n_sample:
                    rand_negatives = random.sample(range(collection_size), n_sample - len(negatives))
                    negatives.extend(rand_negatives)
                yield curr_q, relevance[curr_q], negatives[:n_sample]
                return
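
# For illustration only, one item yielded by load_ranking might look like
# (the ids here are hypothetical, not from any real ranking file):
#     ('1185869', ['2022782'], ['7066866', '2100001', ...])
# i.e. a query id, its judged-relevant passage ids from the qrels, and up to
# n_sample hard-negative passage ids drawn from the ranking.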
random.seed(datetime.now().timestamp())  # time-based seed; a float avoids the deprecated non-numeric seed types
parser = ArgumentParser()
parser.add_argument('--tokenizer_name', required=True)
parser.add_argument('--hn_file', required=True)
parser.add_argument('--qrels', required=True)
parser.add_argument('--queries', required=True)
parser.add_argument('--collection', required=True)
parser.add_argument('--save_to', required=True)
parser.add_argument('--cache_dir', required=True)
parser.add_argument('--sep_token', type=str, default=' ')
parser.add_argument('--truncate', type=int, default=128)
parser.add_argument('--n_sample', type=int, default=30)
parser.add_argument('--depth', type=int, default=200)
parser.add_argument('--mp_chunk_size', type=int, default=500)
parser.add_argument('--shard_size', type=int, default=45000)
args = parser.parse_args()
qrel = TrainPreProcessor.read_qrel(args.qrels)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
processor = TrainPreProcessor(
    query_file=args.queries,
    collection_file=args.collection,
    tokenizer=tokenizer,
    max_length=args.truncate,
    cache_dir=args.cache_dir,
    sep_token=args.sep_token
)
# Shard-writing state: examples are written to split00.hn.json, split01.hn.json, ...
counter = 0
shard_id = 0
f = None
os.makedirs(args.save_to, exist_ok=True)

pbar = tqdm(load_ranking(args.hn_file, qrel, args.n_sample, args.depth, len(processor.collection)))
with Pool() as p:
    # Tokenize and serialize training examples in parallel; each `x` is one JSON line.
    for x in p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size):
        counter += 1

        if f is None:
            f = open(os.path.join(args.save_to, f'split{shard_id:02d}.hn.json'), 'w')
            pbar.set_description(f'split - {shard_id:02d}')
        f.write(x + '\n')

        if counter == args.shard_size:
            # Roll over to a new shard once shard_size examples have been written.
            f.close()
            f = None
            shard_id += 1
            counter = 0

if f is not None:
    f.close()
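
A minimal invocation sketch, assuming the retrieval ranking, qrels, queries, and collection files already exist; the model name and file paths below are placeholders, not part of the script:

    python build_hn.py \
      --tokenizer_name bert-base-uncased \
      --hn_file train.rank.tsv \
      --qrels qrels.train.tsv \
      --queries train.query.txt \
      --collection corpus.tsv \
      --save_to train-hn \
      --cache_dir cache

The optional flags (--sep_token, --truncate, --n_sample, --depth, --mp_chunk_size, --shard_size) keep the defaults defined above; the output is a set of split*.hn.json shards under --save_to.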