From 0bec1f68777fb621c5f139af1327f11f79c3729d Mon Sep 17 00:00:00 2001 From: Tymofii Date: Tue, 12 Mar 2024 05:33:03 +0000 Subject: [PATCH] community[patch]: refactor code for faiss vectorstore, update faiss vectorstore documentation (#18092) **Description:** Refactor code of FAISS vectorstore and update the related documentation. Details: - replace `.format()` with f-strings for string formatting; - refactor definition of a filtering function to make code more readable and more flexible; - slightly improve efficiency of `max_marginal_relevance_search_with_score_by_vector` method by removing unnecessary looping over the same elements; - slightly improve efficiency of `delete` method by using set data structure for checking if the element was already deleted; **Issue:** fix small inconsistency in the documentation (the old example was incorrect and inapplicable to faiss vectorstore) **Dependencies:** basic langchain-community dependencies and `faiss` (for CPU or for GPU) **Twitter handle:** antonenkodev --- .../integrations/vectorstores/faiss.ipynb | 108 +++++++----------- .../langchain_community/vectorstores/faiss.py | 104 ++++++++--------- 2 files changed, 91 insertions(+), 121 deletions(-) diff --git a/docs/docs/integrations/vectorstores/faiss.ipynb b/docs/docs/integrations/vectorstores/faiss.ipynb index cdb600a4397bc..022d5a5003fe3 100644 --- a/docs/docs/integrations/vectorstores/faiss.ipynb +++ b/docs/docs/integrations/vectorstores/faiss.ipynb @@ -83,7 +83,18 @@ "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "42" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Uncomment the following line if you need to initialize FAISS with no AVX2 optimization\n", "# os.environ['FAISS_NO_AVX2'] = '1'\n", @@ -98,7 +109,8 @@ "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "docs = text_splitter.split_documents(documents)\n", "embeddings = 
OpenAIEmbeddings()\n", - "db = FAISS.from_documents(docs, embeddings)" + "db = FAISS.from_documents(docs, embeddings)\n", + "print(db.index.ntotal)" ] }, { @@ -113,7 +125,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "5eabdb75", "metadata": { "tags": [] @@ -126,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "id": "4b172de8", "metadata": { "tags": [] @@ -162,27 +174,18 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "id": "6e91b475-3878-44e0-8720-98d903754b46", "metadata": {}, "outputs": [], "source": [ - "retriever = db.as_retriever()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a869c874-84b5-4d2c-9993-2513f10aee83", - "metadata": {}, - "outputs": [], - "source": [ + "retriever = db.as_retriever()\n", "docs = retriever.invoke(query)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "id": "046739d2-91fe-4101-8b72-c0bcdd9e02b9", "metadata": {}, "outputs": [ @@ -215,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 6, "id": "186ee1d8", "metadata": {}, "outputs": [], @@ -225,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 7, "id": "284e04b5", "metadata": {}, "outputs": [ @@ -236,7 +239,7 @@ " 0.36913747)" ] }, - "execution_count": 14, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -255,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 9, "id": "b558ebb7", "metadata": {}, "outputs": [], @@ -289,7 +292,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 10, "id": "98378c4e", "metadata": {}, "outputs": [ @@ -299,7 +302,7 @@ "Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. 
\\n\\nTonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.', metadata={'source': '../../../state_of_the_union.txt'})" ] }, - "execution_count": 19, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -359,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 11, "id": "83392605", "metadata": {}, "outputs": [ @@ -369,7 +372,7 @@ "{'807e0c63-13f6-4070-9774-5c6f0fbb9866': Document(page_content='bar', metadata={})}" ] }, - "execution_count": 22, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -380,7 +383,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 12, "id": "a3fcc1c7", "metadata": {}, "outputs": [], @@ -390,7 +393,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 13, "id": "41c51f89", "metadata": {}, "outputs": [ @@ -401,7 +404,7 @@ " '807e0c63-13f6-4070-9774-5c6f0fbb9866': Document(page_content='bar', metadata={})}" ] }, - "execution_count": 24, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -421,7 +424,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 14, "id": "d5bf812c", "metadata": {}, "outputs": [ @@ -465,7 +468,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 15, "id": "83159330", "metadata": {}, "outputs": [ @@ -496,7 +499,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 16, 
"id": "432c6980", "metadata": {}, "outputs": [ @@ -525,7 +528,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 17, "id": "1fd60fd1", "metadata": {}, "outputs": [ @@ -550,59 +553,32 @@ "source": [ "## Delete\n", "\n", - "You can also delete ids. Note that the ids to delete should be the ids in the docstore." + "You can also delete records from vectorstore. In the example below `db.index_to_docstore_id` represents a dictionary with elements of the FAISS index." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 18, "id": "1408b870", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "True" + "count before: 8\n", + "count after: 7" ] }, - "execution_count": 4, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "db.delete([db.index_to_docstore_id[0]])" + "print(\"count before:\", db.index.ntotal)\n", + "db.delete([db.index_to_docstore_id[0]])\n", + "print(\"count after:\", db.index.ntotal)" ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "d13daf33", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Is now missing\n", - "0 in db.index_to_docstore_id" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30ace43e", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/libs/community/langchain_community/vectorstores/faiss.py b/libs/community/langchain_community/vectorstores/faiss.py index 32286bb17846c..d01de0068982a 100644 --- a/libs/community/langchain_community/vectorstores/faiss.py +++ b/libs/community/langchain_community/vectorstores/faiss.py @@ -119,9 +119,8 @@ def __init__( and self._normalize_L2 ): warnings.warn( - "Normalizing L2 is not applicable for metric type: {strategy}".format( - strategy=self.distance_strategy - ) + "Normalizing L2 is not applicable for " + 
f"metric type: {self.distance_strategy}" ) @property @@ -306,24 +305,7 @@ def similarity_search_with_score_by_vector( docs = [] if filter is not None: - if isinstance(filter, dict): - - def filter_func(metadata): # type: ignore[no-untyped-def] - if all( - metadata.get(key) in value - if isinstance(value, list) - else metadata.get(key) == value - for key, value in filter.items() - ): - return True - return False - elif callable(filter): - filter_func = filter - else: - raise ValueError( - "filter must be a dict of metadata or " - f"a callable, not {type(filter)}" - ) + filter_func = self._create_filter_func(filter) for j, i in enumerate(indices[0]): if i == -1: @@ -608,25 +590,8 @@ def max_marginal_relevance_search_with_score_by_vector( fetch_k if filter is None else fetch_k * 2, ) if filter is not None: + filter_func = self._create_filter_func(filter) filtered_indices = [] - if isinstance(filter, dict): - - def filter_func(metadata): # type: ignore[no-untyped-def] - if all( - metadata.get(key) in value - if isinstance(value, list) - else metadata.get(key) == value - for key, value in filter.items() - ): - return True - return False - elif callable(filter): - filter_func = filter - else: - raise ValueError( - "filter must be a dict of metadata or " - f"a callable, not {type(filter)}" - ) for i in indices[0]: if i == -1: # This happens when not enough docs are returned. @@ -646,18 +611,18 @@ def filter_func(metadata): # type: ignore[no-untyped-def] k=k, lambda_mult=lambda_mult, ) - selected_indices = [indices[0][i] for i in mmr_selected] - selected_scores = [scores[0][i] for i in mmr_selected] + docs_and_scores = [] - for i, score in zip(selected_indices, selected_scores): - if i == -1: + for i in mmr_selected: + if indices[0][i] == -1: # This happens when not enough docs are returned. 
continue - _id = self.index_to_docstore_id[i] + _id = self.index_to_docstore_id[indices[0][i]] doc = self.docstore.search(_id) if not isinstance(doc, Document): raise ValueError(f"Could not find document for id {_id}, got {doc}") - docs_and_scores.append((doc, score)) + docs_and_scores.append((doc, scores[0][i])) + return docs_and_scores async def amax_marginal_relevance_search_with_score_by_vector( @@ -857,9 +822,9 @@ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[boo ) reversed_index = {id_: idx for idx, id_ in self.index_to_docstore_id.items()} - index_to_delete = [reversed_index[id_] for id_ in ids] + index_to_delete = {reversed_index[id_] for id_ in ids} - self.index.remove_ids(np.array(index_to_delete, dtype=np.int64)) + self.index.remove_ids(np.fromiter(index_to_delete, dtype=np.int64)) self.docstore.delete(ids) remaining_ids = [ @@ -1079,12 +1044,10 @@ def save_local(self, folder_path: str, index_name: str = "index") -> None: # save index separately since it is not picklable faiss = dependable_faiss_import() - faiss.write_index( - self.index, str(path / "{index_name}.faiss".format(index_name=index_name)) - ) + faiss.write_index(self.index, str(path / f"{index_name}.faiss")) # save docstore and index_to_docstore_id - with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f: + with open(path / f"{index_name}.pkl", "wb") as f: pickle.dump((self.docstore, self.index_to_docstore_id), f) @classmethod @@ -1127,12 +1090,10 @@ def load_local( path = Path(folder_path) # load index separately since it is not picklable faiss = dependable_faiss_import() - index = faiss.read_index( - str(path / "{index_name}.faiss".format(index_name=index_name)) - ) + index = faiss.read_index(str(path / f"{index_name}.faiss")) # load docstore and index_to_docstore_id - with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f: + with open(path / f"{index_name}.pkl", "rb") as f: docstore, index_to_docstore_id = 
pickle.load(f) return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs) @@ -1235,3 +1196,36 @@ async def _asimilarity_search_with_relevance_scores( (doc, relevance_score_fn(score)) for doc, score in docs_and_scores ] return docs_and_rel_scores + + @staticmethod + def _create_filter_func( + filter: Optional[Union[Callable, Dict[str, Any]]], + ) -> Callable[[Dict[str, Any]], bool]: + """ + Create a filter function based on the provided filter. + + Args: + filter: A callable or a dictionary representing the filter + conditions for documents. + + Returns: + Callable[[Dict[str, Any]], bool]: A function that takes Document's metadata + and returns True if it satisfies the filter conditions, otherwise False. + """ + if callable(filter): + return filter + + if not isinstance(filter, dict): + raise ValueError( + f"filter must be a dict of metadata or a callable, not {type(filter)}" + ) + + def filter_func(metadata: Dict[str, Any]) -> bool: + return all( + metadata.get(key) in value + if isinstance(value, list) + else metadata.get(key) == value + for key, value in filter.items() # type: ignore + ) + + return filter_func