From d05fdd97ddf1a5302fd9980f1b48823a5c1540b9 Mon Sep 17 00:00:00 2001
From: Stefano Lottini
Date: Wed, 9 Oct 2024 08:41:34 +0200
Subject: [PATCH] community: Cassandra Vector Store: extend metadata-related
 methods (#27078)

**Description:** This PR adds a set of methods to deal with the metadata
associated with the vector store entries. While essential to the
Graph-related extension of the `Cassandra` vector store, these are also
useful in their own right. They are (each in its sync and async version):
- `[a]delete_by_metadata_filter`
- `[a]replace_metadata`
- `[a]get_by_document_id`
- `[a]metadata_search`

Additionally, a `[a]similarity_search_with_embedding_id_by_vector` method is
introduced to better serve the store's internal workings (especially the
reranking logic).

**Issue:** No issue number, but all returned `Document`s now bear their
`.id` consistently (a consequence of a slight refactoring of how the raw
entries read from the DB are turned back into `Document` instances).

**Dependencies:** No new deps: `packaging` comes through langchain-core
already; `cassio` is now required to be version 0.1.10+.

**Add tests and docs:** Added integration tests for the relevant
newly-introduced methods. (Docs will be updated in a separate PR.)

**Lint and test:** Lint and the (updated) tests all pass.

---------

Co-authored-by: Erick Friis
---
 .../vectorstores/cassandra.py                 | 267 +++++++++++++++++-
 .../vectorstores/test_cassandra.py            | 147 +++++++++-
 2 files changed, 396 insertions(+), 18 deletions(-)

diff --git a/libs/community/langchain_community/vectorstores/cassandra.py b/libs/community/langchain_community/vectorstores/cassandra.py
index 85f6460d0133e..3e9ea17cbb0ae 100644
--- a/libs/community/langchain_community/vectorstores/cassandra.py
+++ b/libs/community/langchain_community/vectorstores/cassandra.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import importlib.metadata
 import typing
 import uuid
 from typing import (
@@ -18,6 +19,7 @@
 )
 
 import numpy as np
+from packaging.version import Version  # this is a langchain-core dependency
 
 if typing.TYPE_CHECKING:
     from cassandra.cluster import Session
@@ -30,6 +32,7 @@
 from langchain_community.vectorstores.utils import maximal_marginal_relevance
 
 CVST = TypeVar("CVST", bound="Cassandra")
+MIN_CASSIO_VERSION = Version("0.1.10")
 
 
 class Cassandra(VectorStore):
@@ -110,6 +113,15 @@ def __init__(
                 "Could not import cassio python package. "
                 "Please install it with `pip install cassio`."
             )
+        cassio_version = Version(importlib.metadata.version("cassio"))
+
+        if cassio_version < MIN_CASSIO_VERSION:
+            msg = (
+                "Cassio version not supported. Please upgrade cassio "
+                f"to version {MIN_CASSIO_VERSION} or higher."
+            )
+            raise ImportError(msg)
+
         if not table_name:
             raise ValueError("Missing required parameter 'table_name'.")
         self.embedding = embedding
@@ -143,6 +155,9 @@ def __init__(
             **kwargs,
         )
 
+        if self.session is None:
+            self.session = self.table.session
+
     @property
     def embeddings(self) -> Embeddings:
         return self.embedding
@@ -231,6 +246,70 @@ async def adelete(
             await self.adelete_by_document_id(document_id)
         return True
 
+    def delete_by_metadata_filter(
+        self,
+        filter: dict[str, Any],
+        *,
+        batch_size: int = 50,
+    ) -> int:
+        """Delete all documents matching a certain metadata filtering condition.
+
+        This operation does not use the vector embeddings in any way; it simply
+        removes all documents whose metadata match the provided condition.
+
+        Args:
+            filter: Filter on the metadata to apply. The filter cannot be empty.
+            batch_size: number of deletions per batch (until exhaustion of
+                the matching documents).
+
+        Returns:
+            The number of deleted documents.
+        """
+        if not filter:
+            msg = (
+                "Method `delete_by_metadata_filter` does not accept an empty "
+                "filter. Use the `clear()` method if you really want to empty "
+                "the vector store."
+            )
+            raise ValueError(msg)
+
+        return self.table.find_and_delete_entries(
+            metadata=filter,
+            batch_size=batch_size,
+        )
+
+    async def adelete_by_metadata_filter(
+        self,
+        filter: dict[str, Any],
+        *,
+        batch_size: int = 50,
+    ) -> int:
+        """Delete all documents matching a certain metadata filtering condition.
+
+        This operation does not use the vector embeddings in any way; it simply
+        removes all documents whose metadata match the provided condition.
+
+        Args:
+            filter: Filter on the metadata to apply. The filter cannot be empty.
+            batch_size: number of deletions per batch (until exhaustion of
+                the matching documents).
+
+        Returns:
+            The number of deleted documents.
+        """
+        if not filter:
+            msg = (
+                "Method `adelete_by_metadata_filter` does not accept an empty "
+                "filter. Use the `clear()` method if you really want to empty "
+                "the vector store."
+            )
+            raise ValueError(msg)
+
+        return await self.table.afind_and_delete_entries(
+            metadata=filter,
+            batch_size=batch_size,
+        )
+
     def add_texts(
         self,
         texts: Iterable[str],
@@ -333,6 +412,180 @@ async def send_concurrently(
         await asyncio.gather(*tasks)
         return ids
 
+    def replace_metadata(
+        self,
+        id_to_metadata: dict[str, dict],
+        *,
+        batch_size: int = 50,
+    ) -> None:
+        """Replace the metadata of documents.
+
+        For each document to update, identified by its ID, the new metadata
+        dictionary completely replaces what is on the store. This includes
+        passing empty metadata `{}` to erase the currently-stored information.
+
+        Args:
+            id_to_metadata: map from the IDs of the documents to modify to the
+                new metadata for the update.
+                Keys in this dictionary that do not correspond to an existing
+                document will not cause an error; rather, they will result in
+                new rows being written into the Cassandra table, but without
+                an associated vector: hence unreachable through vector search.
+            batch_size: Number of concurrent requests to send to the server.
+
+        Returns:
+            None if the writes succeed (otherwise an error is raised).
+        """
+        ids_and_metadatas = list(id_to_metadata.items())
+        for i in range(0, len(ids_and_metadatas), batch_size):
+            batch_i_m = ids_and_metadatas[i : i + batch_size]
+            futures = [
+                self.table.put_async(
+                    row_id=doc_id,
+                    metadata=doc_md,
+                )
+                for doc_id, doc_md in batch_i_m
+            ]
+            for future in futures:
+                future.result()
+        return
+
+    async def areplace_metadata(
+        self,
+        id_to_metadata: dict[str, dict],
+        *,
+        concurrency: int = 50,
+    ) -> None:
+        """Replace the metadata of documents.
+
+        For each document to update, identified by its ID, the new metadata
+        dictionary completely replaces what is on the store. This includes
+        passing empty metadata `{}` to erase the currently-stored information.
+
+        Args:
+            id_to_metadata: map from the IDs of the documents to modify to the
+                new metadata for the update.
+                Keys in this dictionary that do not correspond to an existing
+                document will not cause an error; rather, they will result in
+                new rows being written into the Cassandra table, but without
+                an associated vector: hence unreachable through vector search.
+            concurrency: Number of concurrent queries to the database.
+                Defaults to 50.
+
+        Returns:
+            None if the writes succeed (otherwise an error is raised).
+        """
+        ids_and_metadatas = list(id_to_metadata.items())
+
+        sem = asyncio.Semaphore(concurrency)
+
+        async def send_concurrently(doc_id: str, doc_md: dict) -> None:
+            async with sem:
+                await self.table.aput(
+                    row_id=doc_id,
+                    metadata=doc_md,
+                )
+
+        tasks = [
+            asyncio.create_task(send_concurrently(doc_id, doc_md))
+            for doc_id, doc_md in ids_and_metadatas
+        ]
+        await asyncio.gather(*tasks)
+
+        return
+
+    @staticmethod
+    def _row_to_document(row: Dict[str, Any]) -> Document:
+        return Document(
+            id=row["row_id"],
+            page_content=row["body_blob"],
+            metadata=row["metadata"],
+        )
+
+    def get_by_document_id(self, document_id: str) -> Document | None:
+        """Get a document by its document ID.
+
+        Args:
+            document_id: the document ID to get.
+        """
+        row = self.table.get(row_id=document_id)
+        if row is None:
+            return None
+        return self._row_to_document(row=row)
+
+    async def aget_by_document_id(self, document_id: str) -> Document | None:
+        """Get a document by its document ID.
+
+        Args:
+            document_id: the document ID to get.
+        """
+        row = await self.table.aget(row_id=document_id)
+        if row is None:
+            return None
+        return self._row_to_document(row=row)
+
+    def metadata_search(
+        self,
+        metadata: dict[str, Any] = {},  # noqa: B006
+        n: int = 5,
+    ) -> Iterable[Document]:
+        """Get documents via a metadata search.
+
+        Args:
+            metadata: the metadata to query for.
+            n: the maximum number of documents to return.
+        """
+        rows = self.table.find_entries(metadata=metadata, n=n)
+        return [self._row_to_document(row=row) for row in rows if row]
+
+    async def ametadata_search(
+        self,
+        metadata: dict[str, Any] = {},  # noqa: B006
+        n: int = 5,
+    ) -> Iterable[Document]:
+        """Get documents via a metadata search.
+
+        Args:
+            metadata: the metadata to query for.
+            n: the maximum number of documents to return.
+        """
+        rows = await self.table.afind_entries(metadata=metadata, n=n)
+        return [self._row_to_document(row=row) for row in rows]
+
+    async def asimilarity_search_with_embedding_id_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+    ) -> List[Tuple[Document, List[float], str]]:
+        """Return docs most similar to the embedding vector.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filter: Filter on the metadata to apply.
+            body_search: Document textual search terms to apply.
+                Only supported by Astra DB at the moment.
+
+        Returns:
+            List of (Document, embedding, id) tuples, the most similar to the
+            query vector.
+        """
+        kwargs: Dict[str, Any] = {}
+        if filter is not None:
+            kwargs["metadata"] = filter
+        if body_search is not None:
+            kwargs["body_search"] = body_search
+
+        hits = await self.table.aann_search(
+            vector=embedding,
+            n=k,
+            **kwargs,
+        )
+        return [
+            (
+                self._row_to_document(row=hit),
+                hit["vector"],
+                hit["row_id"],
+            )
+            for hit in hits
+        ]
+
     @staticmethod
     def _search_to_documents(
         hits: Iterable[Dict[str, Any]],
@@ -341,10 +594,7 @@ def _search_to_documents(
         # (1=most relevant), as required by this class' contract.
         return [
             (
-                Document(
-                    page_content=hit["body_blob"],
-                    metadata=hit["metadata"],
-                ),
+                Cassandra._row_to_document(row=hit),
                 0.5 + 0.5 * hit["distance"],
                 hit["row_id"],
             )
@@ -375,7 +625,6 @@ def similarity_search_with_score_id_by_vector(
             kwargs["metadata"] = filter
         if body_search is not None:
             kwargs["body_search"] = body_search
-
         hits = self.table.metric_ann_search(
             vector=embedding,
             n=k,
@@ -712,13 +961,7 @@ def _mmr_search_to_documents(
             for pf_index, pf_hit in enumerate(prefetch_hits)
             if pf_index in mmr_chosen_indices
         ]
-        return [
-            Document(
-                page_content=hit["body_blob"],
-                metadata=hit["metadata"],
-            )
-            for hit in mmr_hits
-        ]
+        return [Cassandra._row_to_document(row=hit) for hit in mmr_hits]
 
     def max_marginal_relevance_search_by_vector(
         self,
diff --git a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py
index 014629220f4b8..fd55bab2d3163 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py
@@ -17,6 +17,17 @@
 )
 
 
+def _strip_docs(documents: List[Document]) -> List[Document]:
+    return [_strip_doc(doc) for doc in documents]
+
+
+def _strip_doc(document: Document) -> Document:
+    return Document(
+        page_content=document.page_content,
+        metadata=document.metadata,
+    )
+
+
 def _vectorstore_from_texts(
     texts: List[str],
     metadatas: Optional[List[dict]] = None,
@@ -110,9 +121,9 @@ async def test_cassandra() -> None:
     texts = ["foo", "bar", "baz"]
     docsearch = _vectorstore_from_texts(texts)
     output = docsearch.similarity_search("foo", k=1)
-    assert output == [Document(page_content="foo")]
+    assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])
     output = await docsearch.asimilarity_search("foo", k=1)
-    assert output == [Document(page_content="foo")]
+    assert _strip_docs(output) == _strip_docs([Document(page_content="foo")])
 
 
 async def test_cassandra_with_score() -> None:
@@ -130,13 +141,13 @@ async def test_cassandra_with_score() -> None:
     output = docsearch.similarity_search_with_score("foo", k=3)
     docs = [o[0] for o in output]
     scores = [o[1] for o in output]
-    assert docs == expected_docs
+    assert _strip_docs(docs) == _strip_docs(expected_docs)
     assert scores[0] > scores[1] > scores[2]
 
     output = await docsearch.asimilarity_search_with_score("foo", k=3)
     docs = [o[0] for o in output]
     scores = [o[1] for o in output]
-    assert docs == expected_docs
+    assert _strip_docs(docs) == _strip_docs(expected_docs)
     assert scores[0] > scores[1] > scores[2]
 
 
@@ -239,7 +250,7 @@ async def test_cassandra_no_drop_async() -> None:
 def test_cassandra_delete() -> None:
     """Test delete methods from vector store."""
     texts = ["foo", "bar", "baz", "gni"]
-    metadatas = [{"page": i} for i in range(len(texts))]
+    metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
     docsearch = _vectorstore_from_texts([], metadatas=metadatas)
 
     ids = docsearch.add_texts(texts, metadatas)
@@ -263,11 +274,21 @@ def test_cassandra_delete() -> None:
     output = docsearch.similarity_search("foo", k=10)
     assert len(output) == 0
 
+    docsearch.add_texts(texts, metadatas)
+    num_deleted = docsearch.delete_by_metadata_filter({"mod2": 0}, batch_size=1)
+    assert num_deleted == 2
+    output = docsearch.similarity_search("foo", k=10)
+    assert len(output) == 2
+    docsearch.clear()
+
+    with pytest.raises(ValueError):
+        docsearch.delete_by_metadata_filter({})
+
 
 async def test_cassandra_adelete() -> None:
     """Test delete methods from vector store."""
texts = ["foo", "bar", "baz", "gni"] - metadatas = [{"page": i} for i in range(len(texts))] + metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))] docsearch = await _vectorstore_from_texts_async([], metadatas=metadatas) ids = await docsearch.aadd_texts(texts, metadatas) @@ -291,6 +312,16 @@ async def test_cassandra_adelete() -> None: output = docsearch.similarity_search("foo", k=10) assert len(output) == 0 + await docsearch.aadd_texts(texts, metadatas) + num_deleted = await docsearch.adelete_by_metadata_filter({"mod2": 0}, batch_size=1) + assert num_deleted == 2 + output = await docsearch.asimilarity_search("foo", k=10) + assert len(output) == 2 + await docsearch.aclear() + + with pytest.raises(ValueError): + await docsearch.adelete_by_metadata_filter({}) + def test_cassandra_metadata_indexing() -> None: """Test comparing metadata indexing policies.""" @@ -316,3 +347,107 @@ def test_cassandra_metadata_indexing() -> None: with pytest.raises(ValueError): # "Non-indexed metadata fields cannot be used in queries." vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2) + + +def test_cassandra_replace_metadata() -> None: + """Test of replacing metadata.""" + N_DOCS = 100 + REPLACE_RATIO = 2 # one in ... will have replaced metadata + BATCH_SIZE = 3 + + vstore_f1 = _vectorstore_from_texts( + texts=[], + metadata_indexing=("allowlist", ["field1", "field2"]), + table_name="vector_test_table_indexing", + ) + orig_documents = [ + Document( + page_content=f"doc_{doc_i}", + id=f"doc_id_{doc_i}", + metadata={"field1": f"f1_{doc_i}", "otherf": "pre"}, + ) + for doc_i in range(N_DOCS) + ] + vstore_f1.add_documents(orig_documents) + + ids_to_replace = [ + f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0 + ] + + # various kinds of replacement at play here: + def _make_new_md(mode: int, doc_id: str) -> dict[str, str]: + if mode == 0: + return {} + elif mode == 1: + return {"field2": f"NEW_{doc_id}"} + elif mode == 2: + return {"field2": f"NEW_{doc_id}", "ofherf2": "post"} + else: + return {"ofherf2": "post"} + + ids_to_new_md = { + doc_id: _make_new_md(rep_i % 4, doc_id) + for rep_i, doc_id in enumerate(ids_to_replace) + } + + vstore_f1.replace_metadata(ids_to_new_md, batch_size=BATCH_SIZE) + # thorough check + expected_id_to_metadata: dict[str, dict] = { + **{(document.id or ""): document.metadata for document in orig_documents}, + **ids_to_new_md, + } + for hit in vstore_f1.similarity_search("doc", k=N_DOCS + 1): + assert hit.id is not None + assert hit.metadata == expected_id_to_metadata[hit.id] + + +async def test_cassandra_areplace_metadata() -> None: + """Test of replacing metadata.""" + N_DOCS = 100 + REPLACE_RATIO = 2 # one in ... 
+    BATCH_SIZE = 3
+
+    vstore_f1 = _vectorstore_from_texts(
+        texts=[],
+        metadata_indexing=("allowlist", ["field1", "field2"]),
+        table_name="vector_test_table_indexing",
+    )
+    orig_documents = [
+        Document(
+            page_content=f"doc_{doc_i}",
+            id=f"doc_id_{doc_i}",
+            metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
+        )
+        for doc_i in range(N_DOCS)
+    ]
+    await vstore_f1.aadd_documents(orig_documents)
+
+    ids_to_replace = [
+        f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
+    ]
+
+    # various kinds of replacement at play here:
+    def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
+        if mode == 0:
+            return {}
+        elif mode == 1:
+            return {"field2": f"NEW_{doc_id}"}
+        elif mode == 2:
+            return {"field2": f"NEW_{doc_id}", "otherf2": "post"}
+        else:
+            return {"otherf2": "post"}
+
+    ids_to_new_md = {
+        doc_id: _make_new_md(rep_i % 4, doc_id)
+        for rep_i, doc_id in enumerate(ids_to_replace)
+    }
+
+    await vstore_f1.areplace_metadata(ids_to_new_md, concurrency=BATCH_SIZE)
+    # thorough check
+    expected_id_to_metadata: dict[str, dict] = {
+        **{(document.id or ""): document.metadata for document in orig_documents},
+        **ids_to_new_md,
+    }
+    for hit in await vstore_f1.asimilarity_search("doc", k=N_DOCS + 1):
+        assert hit.id is not None
+        assert hit.metadata == expected_id_to_metadata[hit.id]
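
For quick reference, here is a minimal usage sketch of the methods this patch
introduces (not part of the diff; `my_session`, `my_keyspace` and
`my_embeddings` are placeholders for an open Cassandra `Session`, a keyspace
name and any langchain-core `Embeddings` instance):

```python
from langchain_community.vectorstores import Cassandra

store = Cassandra(
    embedding=my_embeddings,
    session=my_session,
    keyspace=my_keyspace,
    table_name="demo_vector_table",
)
ids = store.add_texts(
    ["quantum physics", "baking bread"],
    metadatas=[{"topic": "science"}, {"topic": "cooking"}],
)

# Retrieve a single document by its ID (returns None if not found):
doc = store.get_by_document_id(ids[0])

# Look up documents purely by metadata; no vectors are involved:
science_docs = store.metadata_search(metadata={"topic": "science"}, n=10)

# Replace (not merge!) the stored metadata of the given documents:
store.replace_metadata({ids[1]: {"topic": "food"}})

# Delete all documents matching a metadata filter; the count is returned:
num_deleted = store.delete_by_metadata_filter({"topic": "food"})
```

Each call above has an async counterpart under the `a`-prefixed name (e.g.
`ametadata_search`), as listed in the description.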