create_faiss_index.py
import faiss
import torch
import random
import tqdm
import pickle
import sys
def random_sample_of_batches(batch_files, proportion):
    """Takes a random sample of batches from all batch_files.
    This is used to make training data for faiss. Proportion
    is a value in the [0, 1] interval."""
    all_batches = []
    batch_files = list(batch_files)
    random.shuffle(batch_files)
    with tqdm.tqdm() as pbar:
        for b in batch_files:
            with open(b, "rb") as f:
                while True:
                    try:
                        sent_idx, embedding_batch = pickle.load(f)
                        # keep this batch with probability `proportion`
                        if random.random() < proportion:
                            all_batches.append(embedding_batch)
                            pbar.update(embedding_batch.shape[0])
                    except EOFError:  # no more batches in this file
                        break
    random.shuffle(all_batches)
    print("Got", len(all_batches), "random batches", file=sys.stderr, flush=True)
    return torch.vstack(all_batches)
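
# Assumed input format (inferred from the reading loops in this script, not
# confirmed against embed.py itself): each batch file is a sequence of pickled
# (indices, embeddings) tuples, where `embeddings` is a torch tensor of shape
# (batch_size, 768). A hypothetical writer would look roughly like:
#
#   with open("emb_batch_0.pkl", "wb") as f:
#       for idx_tensor, emb_tensor in batches:
#           pickle.dump((idx_tensor, emb_tensor), f)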
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("BATCHFILES", default=None, nargs="+",
                        help="Batch files saved by embed.py")
    parser.add_argument("--prepare-sample", default=None,
                        help="File name to save the sampled examples to. Prepares a sample from BATCHFILES on which faiss can be trained. Samples 10 percent of batches by default.")
    parser.add_argument("--train-faiss", default=None,
                        help="File name to save the trained faiss index to. BATCHFILES should be a single .pt file produced by --prepare-sample")
    parser.add_argument("--fill-faiss", default=None,
                        help="Fill the faiss index with vectors and save it under the name given in this argument. BATCHFILES are all batch files to store into the index (they will be sorted by name). Give the name of the trained index (trained with --train-faiss) in the --pretrained-index argument.")
    parser.add_argument("--pretrained-index", default=None,
                        help="Name of the pretrained index to be used with --fill-faiss")
    args = parser.parse_args()
    if args.prepare_sample:
        sampled = random_sample_of_batches(sorted(args.BATCHFILES), 0.1)
        torch.save(sampled, args.prepare_sample)
    elif args.train_faiss:
        assert len(args.BATCHFILES) == 1, "Give one argument which is a .pt file produced by --prepare-sample"
        quantizer = faiss.IndexFlatL2(768)
        # 768 is the BERT embedding size, 1024 is the number of Voronoi cells,
        # 48 is the number of sub-quantizers, each encoded with 8 bits
        index = faiss.IndexIVFPQ(quantizer, 768, 1024, 48, 8)
        sampled_vectors = torch.load(args.BATCHFILES[0])
        print("Training on", sampled_vectors.shape, "vectors", flush=True)
        # how come this doesn't take any time at all...?
        index.train(sampled_vectors.numpy())
        print("Done training", flush=True)
        trained_index = index
        faiss.write_index(trained_index, args.train_faiss)
    elif args.fill_faiss:
        index = faiss.read_index(args.pretrained_index)
        all_batches = list(sorted(args.BATCHFILES))
        for batchfile in tqdm.tqdm(all_batches):
            with open(batchfile, "rb") as f:
                while True:
                    try:
                        line_idx, embedded_batch = pickle.load(f)
                        index.add_with_ids(
                            embedded_batch.numpy(), line_idx.numpy())
                    except EOFError:
                        break  # no more batches in this file
        index_filled = index
        faiss.write_index(index_filled, args.fill_faiss)
        print("Index has", index_filled.ntotal, "vectors. Done.")