Skip to content

Commit

Permalink
community: Cassandra Vector Store: extend metadata-related methods (#…
Browse files Browse the repository at this point in the history
…27078)

**Description:** this PR adds a set of methods to deal with metadata
associated to the vector store entries. These, while essential to the
Graph-related extension of the `Cassandra` vector store, are also useful
in themselves. These are (all come in their sync+async versions):

- `[a]delete_by_metadata_filter`
- `[a]replace_metadata`
- `[a]get_by_document_id`
- `[a]metadata_search`

Additionally, a `[a]similarity_search_with_embedding_id_by_vector`
method is introduced to better serve the store's internal working (esp.
related to reranking logic).

**Issue:** no issue number, but now all Document's returned bear their
`.id` consistently (as a consequence of a slight refactoring in how the
raw entries read from DB are made back into `Document` instances).

**Dependencies:** (no new deps: packaging comes through langchain-core
already; `cassio` is now required to be version 0.1.10+)


**Add tests and docs**
Added integration tests for the relevant newly-introduced methods.
(Docs will be updated in a separate PR).

**Lint and test** Lint and (updated) test all pass.

---------

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
hemidactylus and efriis authored Oct 9, 2024
1 parent 84c05b0 commit d05fdd9
Show file tree
Hide file tree
Showing 2 changed files with 396 additions and 18 deletions.
267 changes: 255 additions & 12 deletions libs/community/langchain_community/vectorstores/cassandra.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import importlib.metadata
import typing
import uuid
from typing import (
Expand All @@ -18,6 +19,7 @@
)

import numpy as np
from packaging.version import Version # this is a lancghain-core dependency

if typing.TYPE_CHECKING:
from cassandra.cluster import Session
Expand All @@ -30,6 +32,7 @@
from langchain_community.vectorstores.utils import maximal_marginal_relevance

CVST = TypeVar("CVST", bound="Cassandra")
MIN_CASSIO_VERSION = Version("0.1.10")


class Cassandra(VectorStore):
Expand Down Expand Up @@ -110,6 +113,15 @@ def __init__(
"Could not import cassio python package. "
"Please install it with `pip install cassio`."
)
cassio_version = Version(importlib.metadata.version("cassio"))

if cassio_version is not None and cassio_version < MIN_CASSIO_VERSION:
msg = (
"Cassio version not supported. Please upgrade cassio "
f"to version {MIN_CASSIO_VERSION} or higher."
)
raise ImportError(msg)

if not table_name:
raise ValueError("Missing required parameter 'table_name'.")
self.embedding = embedding
Expand Down Expand Up @@ -143,6 +155,9 @@ def __init__(
**kwargs,
)

if self.session is None:
self.session = self.table.session

@property
def embeddings(self) -> Embeddings:
return self.embedding
Expand Down Expand Up @@ -231,6 +246,70 @@ async def adelete(
await self.adelete_by_document_id(document_id)
return True

def delete_by_metadata_filter(
self,
filter: dict[str, Any],
*,
batch_size: int = 50,
) -> int:
"""Delete all documents matching a certain metadata filtering condition.
This operation does not use the vector embeddings in any way, it simply
removes all documents whose metadata match the provided condition.
Args:
filter: Filter on the metadata to apply. The filter cannot be empty.
batch_size: amount of deletions per each batch (until exhaustion of
the matching documents).
Returns:
A number expressing the amount of deleted documents.
"""
if not filter:
msg = (
"Method `delete_by_metadata_filter` does not accept an empty "
"filter. Use the `clear()` method if you really want to empty "
"the vector store."
)
raise ValueError(msg)

return self.table.find_and_delete_entries(
metadata=filter,
batch_size=batch_size,
)

async def adelete_by_metadata_filter(
self,
filter: dict[str, Any],
*,
batch_size: int = 50,
) -> int:
"""Delete all documents matching a certain metadata filtering condition.
This operation does not use the vector embeddings in any way, it simply
removes all documents whose metadata match the provided condition.
Args:
filter: Filter on the metadata to apply. The filter cannot be empty.
batch_size: amount of deletions per each batch (until exhaustion of
the matching documents).
Returns:
A number expressing the amount of deleted documents.
"""
if not filter:
msg = (
"Method `delete_by_metadata_filter` does not accept an empty "
"filter. Use the `clear()` method if you really want to empty "
"the vector store."
)
raise ValueError(msg)

return await self.table.afind_and_delete_entries(
metadata=filter,
batch_size=batch_size,
)

def add_texts(
self,
texts: Iterable[str],
Expand Down Expand Up @@ -333,6 +412,180 @@ async def send_concurrently(
await asyncio.gather(*tasks)
return ids

def replace_metadata(
self,
id_to_metadata: dict[str, dict],
*,
batch_size: int = 50,
) -> None:
"""Replace the metadata of documents.
For each document to update, identified by its ID, the new metadata
dictionary completely replaces what is on the store. This includes
passing empty metadata `{}` to erase the currently-stored information.
Args:
id_to_metadata: map from the Document IDs to modify to the
new metadata for updating.
Keys in this dictionary that do not correspond to an existing
document will not cause an error, rather will result in new
rows being written into the Cassandra table but without an
associated vector: hence unreachable through vector search.
batch_size: Number of concurrent requests to send to the server.
Returns:
None if the writes succeed (otherwise an error is raised).
"""
ids_and_metadatas = list(id_to_metadata.items())
for i in range(0, len(ids_and_metadatas), batch_size):
batch_i_m = ids_and_metadatas[i : i + batch_size]
futures = [
self.table.put_async(
row_id=doc_id,
metadata=doc_md,
)
for doc_id, doc_md in batch_i_m
]
for future in futures:
future.result()
return

async def areplace_metadata(
self,
id_to_metadata: dict[str, dict],
*,
concurrency: int = 50,
) -> None:
"""Replace the metadata of documents.
For each document to update, identified by its ID, the new metadata
dictionary completely replaces what is on the store. This includes
passing empty metadata `{}` to erase the currently-stored information.
Args:
id_to_metadata: map from the Document IDs to modify to the
new metadata for updating.
Keys in this dictionary that do not correspond to an existing
document will not cause an error, rather will result in new
rows being written into the Cassandra table but without an
associated vector: hence unreachable through vector search.
concurrency: Number of concurrent queries to the database.
Defaults to 50.
Returns:
None if the writes succeed (otherwise an error is raised).
"""
ids_and_metadatas = list(id_to_metadata.items())

sem = asyncio.Semaphore(concurrency)

async def send_concurrently(doc_id: str, doc_md: dict) -> None:
async with sem:
await self.table.aput(
row_id=doc_id,
metadata=doc_md,
)

for doc_id, doc_md in ids_and_metadatas:
tasks = [asyncio.create_task(send_concurrently(doc_id, doc_md))]
await asyncio.gather(*tasks)

return

@staticmethod
def _row_to_document(row: Dict[str, Any]) -> Document:
return Document(
id=row["row_id"],
page_content=row["body_blob"],
metadata=row["metadata"],
)

def get_by_document_id(self, document_id: str) -> Document | None:
"""Get by document ID.
Args:
document_id: the document ID to get.
"""
row = self.table.get(row_id=document_id)
if row is None:
return None
return self._row_to_document(row=row)

async def aget_by_document_id(self, document_id: str) -> Document | None:
"""Get by document ID.
Args:
document_id: the document ID to get.
"""
row = await self.table.aget(row_id=document_id)
if row is None:
return None
return self._row_to_document(row=row)

def metadata_search(
self,
metadata: dict[str, Any] = {}, # noqa: B006
n: int = 5,
) -> Iterable[Document]:
"""Get documents via a metadata search.
Args:
metadata: the metadata to query for.
"""
rows = self.table.find_entries(metadata=metadata, n=n)
return [self._row_to_document(row=row) for row in rows if row]

async def ametadata_search(
self,
metadata: dict[str, Any] = {}, # noqa: B006
n: int = 5,
) -> Iterable[Document]:
"""Get documents via a metadata search.
Args:
metadata: the metadata to query for.
"""
rows = await self.table.afind_entries(metadata=metadata, n=n)
return [self._row_to_document(row=row) for row in rows]

async def asimilarity_search_with_embedding_id_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
body_search: Optional[Union[str, List[str]]] = None,
) -> List[Tuple[Document, List[float], str]]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
body_search: Document textual search terms to apply.
Only supported by Astra DB at the moment.
Returns:
List of (Document, embedding, id), the most similar to the query vector.
"""
kwargs: Dict[str, Any] = {}
if filter is not None:
kwargs["metadata"] = filter
if body_search is not None:
kwargs["body_search"] = body_search

hits = await self.table.aann_search(
vector=embedding,
n=k,
**kwargs,
)
return [
(
self._row_to_document(row=hit),
hit["vector"],
hit["row_id"],
)
for hit in hits
]

@staticmethod
def _search_to_documents(
hits: Iterable[Dict[str, Any]],
Expand All @@ -341,10 +594,7 @@ def _search_to_documents(
# (1=most relevant), as required by this class' contract.
return [
(
Document(
page_content=hit["body_blob"],
metadata=hit["metadata"],
),
Cassandra._row_to_document(row=hit),
0.5 + 0.5 * hit["distance"],
hit["row_id"],
)
Expand Down Expand Up @@ -375,7 +625,6 @@ def similarity_search_with_score_id_by_vector(
kwargs["metadata"] = filter
if body_search is not None:
kwargs["body_search"] = body_search

hits = self.table.metric_ann_search(
vector=embedding,
n=k,
Expand Down Expand Up @@ -712,13 +961,7 @@ def _mmr_search_to_documents(
for pf_index, pf_hit in enumerate(prefetch_hits)
if pf_index in mmr_chosen_indices
]
return [
Document(
page_content=hit["body_blob"],
metadata=hit["metadata"],
)
for hit in mmr_hits
]
return [Cassandra._row_to_document(row=hit) for hit in mmr_hits]

def max_marginal_relevance_search_by_vector(
self,
Expand Down
Loading

0 comments on commit d05fdd9

Please sign in to comment.