diff --git a/README.md b/README.md
index 5abea20..c8e64c0 100644
--- a/README.md
+++ b/README.md
@@ -32,11 +32,11 @@ See example in `scripts/contriever_scifact.py`.
 
 Note that the nDCG@10 we get for BM25 is much better than in the [paper][2]: instead of `66.5` on row 0, we get '68.4'. The contriever result is also a bit better, with `68.3` instead of `67.7`. Not sure what kind of magic pyterrier is doing here 🤷.
 
-Note that, by default, this codebase uses exhaustive search when querying the dense index. This is not ideal for performance, but it is the setting contriever was evaluated on. If you want to switch to approximate search, you can do so by setting the `factory_config` attribute of `SentenceTransformersRetriever` / `SentenceTransformersIndexer` to any valid index factory string (or pass `factory_config=` to the `contriever_scifact.py` script). I recommend checking out [the faiss docs][3] for more info on the various approximate search options; a good starting point is probably `HNSW`:
+Note that, by default, this codebase uses exhaustive search when querying the dense index. This is not ideal for performance, but it is the setting contriever was evaluated on. If you want to switch to approximate search, you can do so by setting the `faiss_factory_config` attribute of `SentenceTransformersRetriever` / `SentenceTransformersIndexer` to any valid index factory string (or pass `faiss_factory_config=` to the `contriever_scifact.py` script). I recommend checking out [the faiss docs][3] for more info on the various approximate search options; a good starting point is probably `HNSW`:
 
 ```bash
 python scripts/contriever_scifact.py \
-    factory_config='HNSW32' \
+    faiss_factory_config='HNSW32' \
     per_call_size=1024
 ```
 
@@ -50,7 +50,7 @@ This gets you close performance to the exact search:
 
 Note Note that sometimes you might have to increment the number of passages batch batch (`per_call_size`); this is because the approximate search gets trained using the first batch of passages, and the more passages you have, the better the search will be.
 
-In the example above, switching to `factory_config='HNSW64'` gets you another point of accuracyin nDCG@10, but it will increase query time.
+In the example above, switching to `faiss_factory_config='HNSW64'` gets you another point of accuracy in nDCG@10, but it will increase query time.
 
 [1]: https://github.com/facebookresearch/faiss/blob/main/INSTALL.md
 [2]: https://arxiv.org/pdf/2112.09118.pdf
diff --git a/src/pyterrier_sentence_transformers/base.py b/src/pyterrier_sentence_transformers/base.py
index 88e9900..d9a31bc 100644
--- a/src/pyterrier_sentence_transformers/base.py
+++ b/src/pyterrier_sentence_transformers/base.py
@@ -30,11 +30,9 @@ class SentenceTransformerConfig:
     per_gpu_eval_batch_size: int = 128
     per_call_size: int = 1_024
     num_results: int = 1000
-    # faiss_n_subquantizers: int = 0
     normalize: bool = True
-    # faiss_n_bits: int = 8
-    factory_config: str = 'Flat'
-    factory_metric: str = 'METRIC_INNER_PRODUCT'
+    faiss_factory_config: str = 'Flat'
+    faiss_factory_metric: str = 'METRIC_INNER_PRODUCT'
     n_gpu: int = torch.cuda.device_count()
 
     @property
@@ -135,8 +133,8 @@ def faiss_index(self) -> FaissIndex:
         # then written to disk
         index = FaissIndex(
             vector_sz=embedding_size,
-            factory_config=self.config.factory_config,
-            factory_metric=self.config.factory_metric,
+            factory_config=self.config.faiss_factory_config,
+            factory_metric=self.config.faiss_factory_metric,
         )
         return index
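
Review note: the rename in this diff touches only the config fields; the keyword arguments passed to `FaissIndex` keep their unprefixed names. Below is a minimal sketch of that mapping. `ConfigSketch`, `FaissIndexStub`, and `build_index` are hypothetical stand-ins that are not part of the package, the stub only records its arguments instead of building a real faiss index, and the embedding size of 768 is an assumed value chosen for illustration.

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for SentenceTransformerConfig (renamed fields only)."""
    faiss_factory_config: str = 'Flat'
    faiss_factory_metric: str = 'METRIC_INNER_PRODUCT'


@dataclass
class FaissIndexStub:
    """Hypothetical stand-in for FaissIndex; it stores its arguments, nothing more."""
    vector_sz: int
    factory_config: str
    factory_metric: str


def build_index(config: ConfigSketch, embedding_size: int) -> FaissIndexStub:
    # The config fields carry the new `faiss_` prefix, while the
    # constructor keywords keep their original, unprefixed names.
    return FaissIndexStub(
        vector_sz=embedding_size,
        factory_config=config.faiss_factory_config,
        factory_metric=config.faiss_factory_metric,
    )


# 'Flat' (exhaustive search) is the default; a different factory string opts in to HNSW.
# The embedding size of 768 is an assumption made for this example.
default_index = build_index(ConfigSketch(), embedding_size=768)
hnsw_index = build_index(ConfigSketch(faiss_factory_config='HNSW32'), embedding_size=768)
```

Any caller that still sets `factory_config` on the config object would need to switch to the prefixed name after this change, whereas code that constructs `FaissIndex` directly is unaffected.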