-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_atlas_index_faiss.py
378 lines (325 loc) · 15.9 KB
/
build_atlas_index_faiss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
# This script is used to build the index for similarity search using faiss. It supports loading embeddings and meta labels from directories and add to faiss index. The index is saved to disk for later use.
# Options can be set to customize the building process, including:
# - embedding_dir: the directory to the embedding files
# - meta_dir: the directory to the meta files
# - embedding_file_suffix: the suffix of the embedding files
# - meta_file_suffix: the suffix of the meta files
# - embedding_key: the key to the embedding in the embedding file
# - meta_key: the key to the meta in the meta file
# - gpu: whether to use gpu for building the index
# - num_workers: the number of threads to use for building the index
# - index_desc: the type of index to build; different index types may suit fast or memory-efficient building
# - output_dir: the directory to save the index
import json
import os
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
import faiss
import h5py
import numpy as np
import scanpy as sc
from tqdm import tqdm
PathLike = Union[str, os.PathLike]
class FaissIndexBuilder:
    """
    Build an index for similarity search using faiss.

    Loads cell embeddings (and their metadata labels) from files under a
    directory, trains a faiss index described by an index-factory string,
    adds all embeddings, and saves the index + metadata + config to disk.
    """

    def __init__(
        self,
        embedding_dir: PathLike,
        output_dir: PathLike,
        meta_dir: Optional[PathLike] = None,
        recursive: bool = True,
        embedding_file_suffix: str = ".h5ad",
        meta_file_suffix: Optional[str] = None,
        embedding_key: Optional[str] = None,
        meta_key: Optional[str] = "cell_type",
        gpu: bool = False,
        num_workers: Optional[int] = None,
        index_desc: str = "PCA64,IVF16384_HNSW32,PQ16",
    ):
        """
        Initialize an AtlasIndexBuilder object.

        Args:
            embedding_dir (PathLike): Path to the directory containing the input embeddings.
            output_dir (PathLike): Path to the directory where the index files will be saved.
            meta_dir (PathLike, optional): Path to the directory containing the metadata
                files. If None, the metadata will be loaded from the embedding files.
                Defaults to None.
            recursive (bool): Whether to search the embedding and meta directory
                recursively. Defaults to True.
            embedding_file_suffix (str, optional): Suffix of the embedding files.
                Defaults to ".h5ad" in AnnData format.
            meta_file_suffix (str, optional): Suffix of the metadata files. Defaults to None.
            embedding_key (str, optional): Key to access the embeddings in the input
                files. If None, will require the input files to be in AnnData format
                and use the X field. Defaults to None.
            meta_key (str, optional): Key to access the metadata in the input files.
                Defaults to "cell_type".
            gpu (bool): Whether to use GPU acceleration. Defaults to False.
            num_workers (int, optional): Number of threads to use for CPU parallelism.
                If None, will use all available cores. Defaults to None.
            index_desc (str, optional): Faiss index factory string, see
                [here](https://github.com/facebookresearch/faiss/wiki/The-index-factory)
                and
                [here](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index#if-1m---10m-ivf65536_hnsw32).
                Defaults to "PCA64,IVF16384_HNSW32,PQ16".
        """
        self.embedding_dir = embedding_dir
        self.output_dir = output_dir
        self.meta_dir = meta_dir
        self.recursive = recursive
        self.embedding_file_suffix = embedding_file_suffix
        self.meta_file_suffix = meta_file_suffix
        self.embedding_key = embedding_key
        self.meta_key = meta_key
        self.gpu = gpu
        self.num_workers = num_workers
        self.index_desc = index_desc

        if self.num_workers is None:
            try:
                # Prefer the affinity mask: it respects cgroup/taskset limits.
                # Not available on all platforms (e.g. macOS), hence the fallback.
                self.num_workers = len(os.sched_getaffinity(0))
                print("Number of available cores: {}".format(self.num_workers))
            except Exception:
                # os.cpu_count() may return None; guard so min() cannot see None.
                self.num_workers = min(10, os.cpu_count() or 1)

        # Fix: always define the flag. Previously it was only assigned when
        # meta_dir was None, so passing a meta_dir made _load_data crash with
        # an AttributeError instead of a meaningful error.
        self.META_FROM_EMBEDDING = self.meta_dir is None

        if self.embedding_key is None:
            if embedding_file_suffix != ".h5ad":
                raise ValueError(
                    "embedding_key is required when embedding_file_suffix is not .h5ad"
                )

        # Index-factory references:
        # - https://github.com/facebookresearch/faiss/wiki/Lower-memory-footprint#simplifying-index-construction
        # - https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
        # - Clustering options: https://gist.github.com/mdouze/46d6bbbaabca0b9778fca37ed2bcccf6
        # - Benchmarks: https://github.com/facebookresearch/faiss/wiki/Indexing-1G-vectors#10m-datasets

    def _load_data(self):
        """Load embeddings and meta labels from ``self.embedding_dir``.

        Returns:
            Tuple[np.ndarray, np.ndarray]: float32 embeddings of shape
            (n_cells, n_features) and the matching 1D array of meta labels.
        """
        embeddings = []
        meta_labels = []
        # Only the AnnData path (metadata inside the embedding files) is implemented.
        if self.META_FROM_EMBEDDING and self.embedding_file_suffix == ".h5ad":
            embedding_files = (
                list(Path(self.embedding_dir).rglob("*" + self.embedding_file_suffix))
                if self.recursive
                else list(
                    Path(self.embedding_dir).glob("*" + self.embedding_file_suffix)
                )
            )
            # Sort for a deterministic ordering of the concatenated embeddings.
            embedding_files = sorted(str(f) for f in embedding_files)
            if self.num_workers > 1:
                # Parallel loading not implemented yet.
                raise NotImplementedError
            else:
                for file in tqdm(
                    embedding_files, desc="Loading embeddings and metalabels"
                ):
                    adata = sc.read(file)
                    # TODO: set the embedding_key according to self.embedding_key
                    embedding = adata.X.astype(np.float32)
                    if not isinstance(embedding, np.ndarray):
                        # X may be a sparse matrix; faiss needs a dense array.
                        embedding = embedding.toarray().astype(np.float32)
                    meta_label = adata.obs[self.meta_key].values
                    embeddings.append(embedding)
                    meta_labels.append(meta_label)
                    # Free the AnnData object promptly to bound peak memory.
                    del adata
        else:
            raise NotImplementedError
        embeddings = np.concatenate(embeddings, axis=0, dtype=np.float32)
        meta_labels = np.concatenate(meta_labels, axis=0)
        if embeddings.shape[0] != meta_labels.shape[0]:
            raise RuntimeError(
                "embedding/meta row count mismatch: "
                f"{embeddings.shape[0]} vs {meta_labels.shape[0]}"
            )
        return embeddings, meta_labels

    def build_index(self) -> Tuple[faiss.Index, np.ndarray]:
        """Train and populate the faiss index, then save it to ``self.output_dir``.

        Returns:
            Tuple[faiss.Index, np.ndarray]: the built index (GPU-resident if
            ``self.gpu``) and the 1D array of meta labels.
        """
        # Load embeddings and meta labels
        embeddings, meta_labels = self._load_data()

        # Build index
        index = faiss.index_factory(
            embeddings.shape[1], self.index_desc, faiss.METRIC_L2
        )
        nprobe = _auto_set_nprobe(index)
        if self.gpu:
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index)
        index.verbose = True
        print(
            f"Training index {self.index_desc} on {embeddings.shape[0]} embeddings ..."
        )
        index.train(embeddings)
        print("Adding embeddings to index ...")
        index.add(embeddings)

        # Save index
        os.makedirs(self.output_dir, exist_ok=True)
        # create sub folder if the output_dir is not empty, and throw warning
        if len(os.listdir(self.output_dir)) > 0:
            print(
                f"Warning: the output_dir {self.output_dir} is not empty, the index will be saved to a sub folder named index"
            )
            self.output_dir = os.path.join(self.output_dir, "index")
            # Explicit raise instead of assert: asserts are stripped under -O,
            # which would silently overwrite an existing index directory.
            if os.path.exists(self.output_dir):
                raise FileExistsError(
                    f"index sub folder already exists: {self.output_dir}"
                )
            os.makedirs(self.output_dir, exist_ok=True)

        # save index file, meta file and a json of index config params
        index_file = os.path.join(self.output_dir, "index.faiss")
        meta_file = os.path.join(self.output_dir, "meta.h5ad")
        index_config_file = os.path.join(self.output_dir, "index_config.json")
        # A GPU index cannot be serialized directly; move it back to CPU first.
        faiss.write_index(
            faiss.index_gpu_to_cpu(index) if self.gpu else index, index_file
        )
        with h5py.File(meta_file, "w") as f:
            f.create_dataset(
                "meta_labels", data=meta_labels, compression="gzip", chunks=True
            )
        with open(index_config_file, "w") as f:
            # NOTE(review): dirs are assumed to be plain strings here; a
            # pathlib.Path would not be JSON-serializable — confirm callers.
            json.dump(
                {
                    "embedding_dir": self.embedding_dir,
                    "meta_dir": self.meta_dir,
                    "recursive": self.recursive,
                    "embedding_file_suffix": self.embedding_file_suffix,
                    "meta_file_suffix": self.meta_file_suffix,
                    "embedding_key": self.embedding_key,
                    "meta_key": self.meta_key,
                    "gpu": self.gpu,
                    "num_workers": self.num_workers,
                    "index_desc": self.index_desc,
                    "num_embeddings": embeddings.shape[0],
                    "num_features": embeddings.shape[1],
                    "nprobe": nprobe,
                },
                f,
            )
        print(f"All files saved to {self.output_dir}")
        print(
            f"Index saved to {index_file}, "
            f"file size: {os.path.getsize(index_file) / 1024 / 1024} MB"
        )
        return index, meta_labels

    def load_index(self) -> Tuple[faiss.Index, np.ndarray]:
        """
        Load the index from self.output_dir.

        Returns:
            faiss.Index: The loaded index and meta labels.
        """
        return load_index(self.output_dir, use_config_file=False, use_gpu=self.gpu)
def load_index(
    index_dir: PathLike,
    use_config_file=True,
    use_gpu=False,
    nprobe=None,
) -> Tuple[faiss.Index, np.ndarray]:
    """
    Load a previously built index and its meta labels from disk.

    Args:
        index_dir (PathLike): Path to the directory containing the index files.
        use_config_file (bool, optional): Whether to load the index config file.
            If True, the gpu and nprobe parameters are taken from the config
            file and override the arguments below. Defaults to True.
        use_gpu (bool, optional): Whether to use GPU acceleration. Only used
            when use_config_file is False. Defaults to False.
        nprobe (int, optional): The nprobe to set if the index contains a
            :class:`faiss.IndexIVF`. If None, it is chosen from the number of
            clusters. Only used when use_config_file is False. Defaults to None.

    Returns:
        faiss.Index: The loaded index and meta labels.
    """
    index_path = os.path.join(index_dir, "index.faiss")
    labels_path = os.path.join(index_dir, "meta.h5ad")
    config_path = os.path.join(index_dir, "index_config.json")

    print(f"Loading index and meta from {index_dir} ...")
    index = faiss.read_index(index_path)
    print(f"Index loaded, num_embeddings: {index.ntotal}")

    with h5py.File(labels_path, "r") as store:
        # asstr() decodes the stored bytes back into Python strings.
        meta_labels = store["meta_labels"].asstr()[:]

    if use_config_file:
        with open(config_path, "r") as cfg_fp:
            saved = json.load(cfg_fp)
        use_gpu = saved["gpu"]
        nprobe = saved["nprobe"]

    _auto_set_nprobe(index, nprobe=nprobe)
    if use_gpu:
        gpu_res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(gpu_res, 0, index)
    return index, meta_labels
def _auto_set_nprobe(index: faiss.Index, nprobe: int = None) -> Optional[int]:
    """
    Set nprobe on the IVF component of *index*, if it has one.

    Args:
        index (faiss.Index): The index to set nprobe on.
        nprobe (int, optional): The nprobe to set. If None, a value is chosen
            from the number of clusters. Defaults to None.

    Returns:
        int: The nprobe that was set, or None if the index has no IVF part.
    """
    # Non-IVF indexes have no nprobe to tune.
    ivf = faiss.try_extract_index_ivf(index)
    if not ivf:
        return None

    nlist = ivf.nlist
    previous = ivf.nprobe
    if nprobe is not None:
        ivf.nprobe = nprobe
    else:
        # Heuristic ladder: probe more clusters as nlist grows.
        for limit, candidate in ((1e3, 16), (4e3, 32), (1.6e4, 64)):
            if nlist <= limit:
                ivf.nprobe = candidate
                break
        else:
            ivf.nprobe = 128
    print(
        f"Set nprobe from {previous} to {ivf.nprobe} for {nlist} clusters"
    )
    return ivf.nprobe
def compute_category_proportion(meta_labels) -> Dict[str, float]:
    """
    Compute the proportion of each cell type in the meta_labels, which can be
    used for weighted voting in the search.

    Args:
        meta_labels (numpy.ndarray): A 1D array of cell type labels.

    Returns:
        dict: A dictionary mapping each cell type to its proportion of the input.
    """
    labels, counts = np.unique(meta_labels, return_counts=True)
    total = counts.sum()
    return {label: count / total for label, count in zip(labels, counts)}
def weighted_vote(
    predicts_for_query, cell_type_proportion, return_prob=True
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Use the proportion of each cell type as the weight for voting.

    Args:
        predicts_for_query (numpy.ndarray): A 1D array of the predicted cell types for an individual query.
        cell_type_proportion (dict): A dictionary containing the proportion of each cell type in the meta_labels.
        return_prob (bool, optional): Whether to return the probability of each predicted cell type. Defaults to True.

    Returns:
        numpy.ndarray: A 1D array of the predicted cell types for the input query, weighted by the proportion of each cell type and sorted by the proportion.
        numpy.ndarray: A 1D array of the probability of each predicted cell type. Only returned when return_prob is True.
    """
    unique_labels, counts = np.unique(predicts_for_query, return_counts=True)
    proportions = np.array([cell_type_proportion[l] for l in unique_labels])
    # Subtract 1% of the total votes as a noise floor (clipped at zero), then
    # divide by each label's global proportion so globally over-represented
    # cell types do not dominate the vote.  (The previous "* 1e-3" scaling
    # factor cancelled under normalization and has been removed.)
    weighted_counts = np.clip(counts - 0.01 * counts.sum(), 0, None) / proportions
    total = weighted_counts.sum()
    if total == 0:
        # Every label fell below the noise floor (possible when the votes are
        # spread across many labels); previously this produced NaN probabilities.
        # Fall back to the un-clipped proportion-weighted counts.
        weighted_counts = counts / proportions
        total = weighted_counts.sum()
    weighted_counts = weighted_counts / total
    sorted_idx = np.argsort(weighted_counts)[::-1]
    predicts_for_query = unique_labels[sorted_idx]
    if return_prob:
        return predicts_for_query, weighted_counts[sorted_idx]
    return predicts_for_query
def vote(predicts_for_query, return_prob=True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Majority voting for the predicted cell types.

    Args:
        predicts_for_query (numpy.ndarray): A 1D array of the predicted cell types for an individual query.
        return_prob (bool, optional): Whether to return the probability of each predicted cell type. Defaults to True.

    Returns:
        numpy.ndarray: A 1D array of the predicted cell types for the input
            query, sorted by vote count (most frequent first).
        numpy.ndarray: A 1D array of the vote fraction for each predicted cell
            type. Only returned when return_prob is True.
    """
    labels, counts = np.unique(predicts_for_query, return_counts=True)
    fractions = counts / counts.sum()
    order = np.argsort(fractions)[::-1]
    if return_prob:
        return labels[order], fractions[order]
    return labels[order]
if __name__ == "__main__":
    # Example driver: build a faiss index over AnnData embedding files.
    # NOTE(review): the paths below are hard-coded to a specific cluster
    # layout — adjust before running elsewhere.
    # Set options
    embedding_dir = "/scratch/hdd001/home/haotian/cellxgene_cencus_embed/"
    output_dir = "/scratch/hdd001/home/haotian/projects/cellxemb/normal"
    embedding_file_suffix = ".h5ad"
    gpu = True  # requires a faiss-gpu build and a visible CUDA device
    index_desc = "PCA64,IVF16384_HNSW32,PQ16"  # faiss index-factory string
    num_workers = 1  # _load_data raises NotImplementedError for >1
    # Build index
    builder = FaissIndexBuilder(
        embedding_dir,
        output_dir=output_dir,
        embedding_file_suffix=embedding_file_suffix,
        gpu=gpu,
        num_workers=num_workers,
        index_desc=index_desc,
    )
    index, meta_labels = builder.build_index()