From 094f2dd3d449c9beedbfc59e60975563f53eeeab Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 11 Oct 2022 14:31:19 -0700 Subject: [PATCH] fixed not being added while training --- src/pyterrier_sentence_transformers/index.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pyterrier_sentence_transformers/index.py b/src/pyterrier_sentence_transformers/index.py index ae0b366..9ed5a46 100644 --- a/src/pyterrier_sentence_transformers/index.py +++ b/src/pyterrier_sentence_transformers/index.py @@ -67,10 +67,11 @@ def __init__( def index_data(self, ids: List[Any], embeddings: np.ndarray): self._update_id_mapping(ids) embeddings = embeddings.astype('float32') + if not self.index.is_trained: self.index.train(embeddings) - else: - self.index.add(embeddings) + + self.index.add(embeddings) print(f'Total data indexed {len(self.index_id_to_db_id)}')