diff --git a/libs/partners/astradb/langchain_astradb/storage.py b/libs/partners/astradb/langchain_astradb/storage.py index da4cc58593d4e..1e1ec9a08ef1a 100644 --- a/libs/partners/astradb/langchain_astradb/storage.py +++ b/libs/partners/astradb/langchain_astradb/storage.py @@ -29,7 +29,20 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): """Base class for the DataStax AstraDB data store.""" def __init__(self, *args: Any, **kwargs: Any) -> None: - self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs) + if "requested_indexing_policy" in kwargs: + raise ValueError( + "Do not pass 'requested_indexing_policy' to AstraDBBaseStore init" + ) + if "default_indexing_policy" in kwargs: + raise ValueError( + "Do not pass 'default_indexing_policy' to AstraDBBaseStore init" + ) + kwargs["requested_indexing_policy"] = {"allow": ["_id"]} + kwargs["default_indexing_policy"] = {"allow": ["_id"]} + self.astra_env = _AstraDBCollectionEnvironment( + *args, + **kwargs, + ) self.collection = self.astra_env.collection self.async_collection = self.astra_env.async_collection diff --git a/libs/partners/astradb/langchain_astradb/utils/astradb.py b/libs/partners/astradb/langchain_astradb/utils/astradb.py index b1869a8bff601..769b74ffcb508 100644 --- a/libs/partners/astradb/langchain_astradb/utils/astradb.py +++ b/libs/partners/astradb/langchain_astradb/utils/astradb.py @@ -2,11 +2,14 @@ import asyncio import inspect +import json +import warnings from asyncio import InvalidStateError, Task from enum import Enum -from typing import Awaitable, Optional, Union +from typing import Any, Awaitable, Dict, List, Optional, Union import langchain_core +from astrapy.api import APIRequestError from astrapy.db import AstraDB, AsyncAstraDB @@ -89,6 +92,8 @@ def __init__( pre_delete_collection: bool = False, embedding_dimension: Union[int, Awaitable[int], None] = None, metric: Optional[str] = None, + requested_indexing_policy: Optional[Dict[str, Any]] = None, + default_indexing_policy: 
Optional[Dict[str, Any]] = None, ) -> None: from astrapy.db import AstraDBCollection, AsyncAstraDBCollection @@ -106,6 +111,11 @@ def __init__( astra_db=self.async_astra_db, ) + if requested_indexing_policy is not None: + _options = {"indexing": requested_indexing_policy} + else: + _options = None + self.async_setup_db_task: Optional[Task] = None if setup_mode == SetupMode.ASYNC: async_astra_db = self.async_astra_db @@ -117,9 +127,31 @@ async def _setup_db() -> None: dimension = await embedding_dimension else: dimension = embedding_dimension - await async_astra_db.create_collection( - collection_name, dimension=dimension, metric=metric - ) + + try: + await async_astra_db.create_collection( + collection_name, + dimension=dimension, + metric=metric, + options=_options, + ) + except (APIRequestError, ValueError): + # possibly the collection is preexisting and may have legacy, + # or custom, indexing settings: verify + get_coll_response = await async_astra_db.get_collections( + options={"explain": True} + ) + collections = (get_coll_response["status"] or {}).get( + "collections" + ) or [] + if not self._validate_indexing_policy( + detected_collections=collections, + collection_name=self.collection_name, + requested_indexing_policy=requested_indexing_policy, + default_indexing_policy=default_indexing_policy, + ): + # other reasons for the exception + raise self.async_setup_db_task = asyncio.create_task(_setup_db()) elif setup_mode == SetupMode.SYNC: @@ -130,12 +162,138 @@ async def _setup_db() -> None: "Cannot use an awaitable embedding_dimension with async_setup " "set to False" ) - self.astra_db.create_collection( - collection_name, - dimension=embedding_dimension, # type: ignore[arg-type] - metric=metric, + else: + try: + self.astra_db.create_collection( + collection_name, + dimension=embedding_dimension, # type: ignore[arg-type] + metric=metric, + options=_options, + ) + except (APIRequestError, ValueError): + # possibly the collection is preexisting and may have 
legacy, + # or custom, indexing settings: verify + get_coll_response = self.astra_db.get_collections( # type: ignore[union-attr] + options={"explain": True} + ) + collections = (get_coll_response["status"] or {}).get( + "collections" + ) or [] + if not self._validate_indexing_policy( + detected_collections=collections, + collection_name=self.collection_name, + requested_indexing_policy=requested_indexing_policy, + default_indexing_policy=default_indexing_policy, + ): + # other reasons for the exception + raise + + @staticmethod + def _validate_indexing_policy( + detected_collections: List[Dict[str, Any]], + collection_name: str, + requested_indexing_policy: Optional[Dict[str, Any]], + default_indexing_policy: Optional[Dict[str, Any]], + ) -> bool: + """ + This is a validation helper, to be called when the collection-creation + call has failed. + + Args: + detected_collection (List[Dict[str, Any]]): + the list of collection items returned by astrapy + collection_name (str): the name of the collection whose attempted + creation failed + requested_indexing_policy: the 'indexing' part of the collection + options, e.g. `{"deny": ["field1", "field2"]}`. + Leave to its default of None if no options required. + default_indexing_policy: an optional 'default value' for the + above, used to issue just a gentle warning in the special + case that no policy is detected on a preexisting collection + on DB and the default is requested. This is to enable + a warning-only transition to new code using indexing without + disrupting usage of a legacy collection, i.e. one created + before adopting the usage of indexing policies altogether. + You cannot pass this one without requested_indexing_policy. + + This function may raise an error (indexing mismatches), issue a warning + (about legacy collections), or do nothing. 
+ In any case, when the function returns, it returns either + - True: the exception was handled here as part of the indexing + management + - False: the exception is unrelated to indexing and the caller + has to reraise it. + """ + if requested_indexing_policy is None and default_indexing_policy is not None: + raise ValueError( + "Cannot specify a default indexing policy " + "when no indexing policy is requested for this collection " + "(requested_indexing_policy is None, " + "default_indexing_policy is not None)." ) + preexisting = [ + collection + for collection in detected_collections + if collection["name"] == collection_name + ] + if preexisting: + pre_collection = preexisting[0] + # if it has no "indexing", it is a legacy collection + pre_col_options = pre_collection.get("options") or {} + if "indexing" not in pre_col_options: + # legacy collection on DB + if requested_indexing_policy == default_indexing_policy: + warnings.warn( + ( + f"Astra DB collection '{collection_name}' is " + "detected as legacy and has indexing turned " + "on for all fields. This implies stricter " + "limitations on the amount of text each string in a " + "document can store. Consider reindexing anew on a " + "fresh collection to be able to store longer texts." + ), + UserWarning, + stacklevel=2, + ) + else: + raise ValueError( + f"Astra DB collection '{collection_name}' is " + "detected as legacy and has indexing turned " + "on for all fields. This is incompatible with " + "the requested indexing policy for this object. " + "Consider reindexing anew on a fresh " + "collection with the requested indexing " + "policy, or alternatively leave the indexing " + "settings for this object to their defaults " + "to keep using this collection." 
+ ) + elif pre_col_options["indexing"] != requested_indexing_policy: + # collection on DB has indexing settings, but different + options_json = json.dumps(pre_col_options["indexing"]) + if pre_col_options["indexing"] == default_indexing_policy: + default_desc = " (default setting)" + else: + default_desc = "" + raise ValueError( + f"Astra DB collection '{collection_name}' is " + "detected as having the following indexing policy: " + f"{options_json}{default_desc}. This is incompatible " + "with the requested indexing policy for this object. " + "Consider reindexing anew on a fresh " + "collection with the requested indexing " + "policy, or alternatively align the requested " + "indexing settings to the collection to keep using it." + ) + else: + # the discrepancies have to do with options other than indexing + return False + # the original exception, related to indexing, was handled here + return True + else: + # foreign-origin for the original exception + return False + def ensure_db_setup(self) -> None: if self.async_setup_db_task: try: diff --git a/libs/partners/astradb/langchain_astradb/vectorstores.py b/libs/partners/astradb/langchain_astradb/vectorstores.py index 3e093dd13a15c..fe7381f91f07f 100644 --- a/libs/partners/astradb/langchain_astradb/vectorstores.py +++ b/libs/partners/astradb/langchain_astradb/vectorstores.py @@ -1,12 +1,11 @@ from __future__ import annotations -import asyncio import uuid import warnings -from asyncio import Task from concurrent.futures import ThreadPoolExecutor from typing import ( Any, + Awaitable, Callable, Dict, Iterable, @@ -16,27 +15,26 @@ Tuple, Type, TypeVar, - cast, + Union, ) import numpy as np from astrapy.db import ( AstraDB as AstraDBClient, ) -from astrapy.db import ( - AstraDBCollection, - AsyncAstraDBCollection, -) from astrapy.db import ( AsyncAstraDB as AsyncAstraDBClient, ) from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.runnables import 
run_in_executor from langchain_core.runnables.utils import gather_with_concurrency from langchain_core.utils.iter import batch_iterate from langchain_core.vectorstores import VectorStore +from langchain_astradb.utils.astradb import ( + SetupMode, + _AstraDBCollectionEnvironment, +) from langchain_astradb.utils.mmr import maximal_marginal_relevance T = TypeVar("T") @@ -54,6 +52,9 @@ # Number of threads (for deleting multiple rows concurrently): DEFAULT_BULK_DELETE_CONCURRENCY = 20 +# indexing options when creating a collection +DEFAULT_INDEXING_OPTIONS = {"allow": ["metadata"]} + def _unique_list(lst: List[T], key: Callable[[T], U]) -> List[T]: visited_keys: Set[U] = set() @@ -87,6 +88,50 @@ def _filter_to_metadata(filter_dict: Optional[Dict[str, Any]]) -> Dict[str, Any] return metadata_filter + @staticmethod + def _normalize_metadata_indexing_policy( + metadata_indexing_include: Optional[Iterable[str]], + metadata_indexing_exclude: Optional[Iterable[str]], + collection_indexing_policy: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + """ + Validate the constructor indexing parameters and normalize them + into a ready-to-use dict for the 'options' when creating a collection. + """ + none_count = sum( + [ + 1 if var is None else 0 + for var in [ + metadata_indexing_include, + metadata_indexing_exclude, + collection_indexing_policy, + ] + ] + ) + if none_count >= 2: + if metadata_indexing_include is not None: + return { + "allow": [ + f"metadata.{md_field}" for md_field in metadata_indexing_include + ] + } + elif metadata_indexing_exclude is not None: + return { + "deny": [ + f"metadata.{md_field}" for md_field in metadata_indexing_exclude + ] + } + elif collection_indexing_policy is not None: + return collection_indexing_policy + else: + return DEFAULT_INDEXING_OPTIONS + else: + raise ValueError( + "At most one of the parameters `metadata_indexing_include`," + " `metadata_indexing_exclude` and `collection_indexing_policy`" + " can be specified as non null." 
+ ) + def __init__( self, *, @@ -102,7 +147,11 @@ def __init__( bulk_insert_batch_concurrency: Optional[int] = None, bulk_insert_overwrite_concurrency: Optional[int] = None, bulk_delete_concurrency: Optional[int] = None, + setup_mode: SetupMode = SetupMode.SYNC, pre_delete_collection: bool = False, + metadata_indexing_include: Optional[Iterable[str]] = None, + metadata_indexing_exclude: Optional[Iterable[str]] = None, + collection_indexing_policy: Optional[Dict[str, Any]] = None, ) -> None: """Wrapper around DataStax Astra DB for vector-store workloads. @@ -151,6 +200,15 @@ def __init__( pre_delete_collection: whether to delete the collection before creating it. If False and the collection already exists, the collection will be used as is. + metadata_indexing_include: an allowlist of the specific metadata subfields + that should be indexed for later filtering in searches. + metadata_indexing_exclude: a denylist of the specific metadata subfields + that should not be indexed for later filtering in searches. + collection_indexing_policy: a full "indexing" specification for + what fields should be indexed for later filtering in searches. + This dict must conform to to the API specifications + (see docs.datastax.com/en/astra/astra-db-vector/api-reference/ + data-api-commands.html#advanced-feature-indexing-clause-on-createcollection) Note: For concurrency in synchronous :meth:`~add_texts`:, as a rule of thumb, on a @@ -169,14 +227,7 @@ def __init__( Remember you can pass concurrency settings to individual calls to :meth:`~add_texts` and :meth:`~add_documents` as well. """ - # Conflicting-arg checks: - if astra_db_client is not None or async_astra_db_client is not None: - if token is not None or api_endpoint is not None: - raise ValueError( - "You cannot pass 'astra_db_client' or 'async_astra_db_client' to " - "AstraDBVectorStore if passing 'token' and 'api_endpoint'." 
- ) - + self.embedding_dimension: Optional[int] = None self.embedding = embedding self.collection_name = collection_name self.token = token @@ -195,110 +246,52 @@ def __init__( bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY ) # "vector-related" settings - self._embedding_dimension: Optional[int] = None self.metric = metric + embedding_dimension: Union[int, Awaitable[int], None] = None + if setup_mode == SetupMode.ASYNC: + embedding_dimension = self._aget_embedding_dimension() + elif setup_mode == SetupMode.SYNC or setup_mode == SetupMode.OFF: + embedding_dimension = self._get_embedding_dimension() + + # indexing policy setting + self.indexing_policy: Dict[str, Any] = self._normalize_metadata_indexing_policy( + metadata_indexing_include=metadata_indexing_include, + metadata_indexing_exclude=metadata_indexing_exclude, + collection_indexing_policy=collection_indexing_policy, + ) - self.astra_db = astra_db_client - self.async_astra_db = async_astra_db_client - self.collection = None - self.async_collection = None - - if token and api_endpoint: - self.astra_db = AstraDBClient( - token=cast(str, self.token), - api_endpoint=cast(str, self.api_endpoint), - namespace=self.namespace, - ) - self.async_astra_db = AsyncAstraDBClient( - token=cast(str, self.token), - api_endpoint=cast(str, self.api_endpoint), - namespace=self.namespace, - ) - - if self.astra_db is not None: - self.collection = AstraDBCollection( - collection_name=self.collection_name, - astra_db=self.astra_db, - ) - - self.async_setup_db_task: Optional[Task] = None - if self.async_astra_db is not None: - self.async_collection = AsyncAstraDBCollection( - collection_name=self.collection_name, - astra_db=self.async_astra_db, - ) - try: - asyncio.get_running_loop() - self.async_setup_db_task = asyncio.create_task( - self._setup_db(pre_delete_collection) - ) - except RuntimeError: - pass - - if self.async_setup_db_task is None: - if not pre_delete_collection: - self._provision_collection() - else: - 
self.clear() - - def _ensure_astra_db_client(self) -> None: - """ - If no error is raised, that means self.collection - is also not None (as per constructor flow). - """ - if not self.astra_db: - raise ValueError("Missing AstraDB client") - - async def _setup_db(self, pre_delete_collection: bool) -> None: - if pre_delete_collection: - # _setup_db is called from the constructor only, from a place - # where async_astra_db is not None for sure - await self.async_astra_db.delete_collection( # type: ignore[union-attr] - collection_name=self.collection_name, - ) - await self._aprovision_collection() - - async def _ensure_db_setup(self) -> None: - if self.async_setup_db_task: - await self.async_setup_db_task + self.astra_env = _AstraDBCollectionEnvironment( + collection_name=collection_name, + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, + embedding_dimension=embedding_dimension, + metric=metric, + requested_indexing_policy=self.indexing_policy, + default_indexing_policy=DEFAULT_INDEXING_OPTIONS, + ) + self.astra_db = self.astra_env.astra_db + self.async_astra_db = self.astra_env.async_astra_db + self.collection = self.astra_env.collection + self.async_collection = self.astra_env.async_collection def _get_embedding_dimension(self) -> int: - if self._embedding_dimension is None: - self._embedding_dimension = len( - self.embedding.embed_query("This is a sample sentence.") + if self.embedding_dimension is None: + self.embedding_dimension = len( + self.embedding.embed_query(text="This is a sample sentence.") ) - return self._embedding_dimension - - def _provision_collection(self) -> None: - """ - Run the API invocation to create the collection on the backend. - - Internal-usage method, no object members are set, - other than working on the underlying actual storage. 
- """ - self._ensure_astra_db_client() - # self.astra_db is not None (by _ensure_astra_db_client) - self.astra_db.create_collection( # type: ignore[union-attr] - dimension=self._get_embedding_dimension(), - collection_name=self.collection_name, - metric=self.metric, - ) - - async def _aprovision_collection(self) -> None: - """ - Run the API invocation to create the collection on the backend. + return self.embedding_dimension - Internal-usage method, no object members are set, - other than working on the underlying actual storage. - """ - if not self.async_astra_db: - await run_in_executor(None, self._provision_collection) - else: - await self.async_astra_db.create_collection( - dimension=self._get_embedding_dimension(), - collection_name=self.collection_name, - metric=self.metric, + async def _aget_embedding_dimension(self) -> int: + if self.embedding_dimension is None: + self.embedding_dimension = len( + await self.embedding.aembed_query(text="This is a sample sentence.") ) + return self.embedding_dimension @property def embeddings(self) -> Embeddings: @@ -319,18 +312,13 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: def clear(self) -> None: """Empty the collection of all its stored entries.""" - self._ensure_astra_db_client() - # self.collection is not None (by _ensure_astra_db_client) - self.collection.delete_many(filter={}) # type: ignore[union-attr] + self.astra_env.ensure_db_setup() + self.collection.delete_many({}) async def aclear(self) -> None: """Empty the collection of all its stored entries.""" - await self._ensure_db_setup() - if not self.async_astra_db: - return await run_in_executor(None, self.clear) - else: - # async_collection not None if so is async_astra_db (constr. 
flow) - await self.async_collection.delete_many({}) # type: ignore[union-attr] + await self.astra_env.aensure_db_setup() + await self.async_collection.delete_many({}) def delete_by_document_id(self, document_id: str) -> bool: """ @@ -342,9 +330,9 @@ def delete_by_document_id(self, document_id: str) -> bool: Returns True if a document has indeed been deleted, False if ID not found. """ - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() # self.collection is not None (by _ensure_astra_db_client) - deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr] + deletion_response = self.collection.delete_one(document_id) return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 ) == 1 @@ -359,9 +347,7 @@ async def adelete_by_document_id(self, document_id: str) -> bool: Returns True if a document has indeed been deleted, False if ID not found. """ - await self._ensure_db_setup() - if not self.async_collection: - return await run_in_executor(None, self.delete_by_document_id, document_id) + await self.astra_env.aensure_db_setup() deletion_response = await self.async_collection.delete_one(document_id) return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 @@ -443,9 +429,8 @@ def delete_collection(self) -> None: Stored data is lost and unrecoverable, resources are freed. Use with caution. """ - self._ensure_astra_db_client() - # self.astra_db is not None (by _ensure_astra_db_client) - self.astra_db.delete_collection( # type: ignore[union-attr] + self.astra_env.ensure_db_setup() + self.astra_db.delete_collection( collection_name=self.collection_name, ) @@ -456,13 +441,10 @@ async def adelete_collection(self) -> None: Stored data is lost and unrecoverable, resources are freed. Use with caution. 
""" - await self._ensure_db_setup() - if not self.async_astra_db: - return await run_in_executor(None, self.delete_collection) - else: - await self.async_astra_db.delete_collection( - collection_name=self.collection_name, - ) + await self.astra_env.aensure_db_setup() + await self.async_astra_db.delete_collection( + collection_name=self.collection_name, + ) @staticmethod def _get_documents_to_insert( @@ -576,7 +558,7 @@ def add_texts( f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." ) - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() embedding_vectors = self.embedding.embed_documents(list(texts)) documents_to_insert = self._get_documents_to_insert( @@ -666,72 +648,60 @@ async def aadd_texts( Returns: The list of ids of the added texts. """ - await self._ensure_db_setup() - if not self.async_collection: - return await super().aadd_texts( - texts, - metadatas, - ids=ids, - batch_size=batch_size, - batch_concurrency=batch_concurrency, - overwrite_concurrency=overwrite_concurrency, + if kwargs: + warnings.warn( + "Method 'aadd_texts' of AstraDBVectorStore invoked with " + f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " + "which will be ignored." ) - else: - if kwargs: - warnings.warn( - "Method 'aadd_texts' of AstraDBVectorStore invoked with " - f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " - "which will be ignored." 
- ) + await self.astra_env.aensure_db_setup() + + embedding_vectors = await self.embedding.aembed_documents(list(texts)) + documents_to_insert = self._get_documents_to_insert( + texts, embedding_vectors, metadatas, ids + ) - embedding_vectors = await self.embedding.aembed_documents(list(texts)) - documents_to_insert = self._get_documents_to_insert( - texts, embedding_vectors, metadatas, ids + async def _handle_batch(document_batch: List[DocDict]) -> List[str]: + # self.async_collection is not None here for sure + im_result = await self.async_collection.insert_many( + documents=document_batch, + options={"ordered": False}, + partial_failures_allowed=True, + ) + batch_inserted, missing_from_batch = self._get_missing_from_batch( + document_batch, im_result ) - async def _handle_batch(document_batch: List[DocDict]) -> List[str]: + async def _handle_missing_document(missing_document: DocDict) -> str: # self.async_collection is not None here for sure - im_result = await self.async_collection.insert_many( # type: ignore[union-attr] - documents=document_batch, - options={"ordered": False}, - partial_failures_allowed=True, - ) - batch_inserted, missing_from_batch = self._get_missing_from_batch( - document_batch, im_result + replacement_result = await self.async_collection.find_one_and_replace( + filter={"_id": missing_document["_id"]}, + replacement=missing_document, ) + return replacement_result["data"]["document"]["_id"] - async def _handle_missing_document(missing_document: DocDict) -> str: - # self.async_collection is not None here for sure - replacement_result = ( - await self.async_collection.find_one_and_replace( # type: ignore[union-attr] - filter={"_id": missing_document["_id"]}, - replacement=missing_document, - ) - ) - return replacement_result["data"]["document"]["_id"] + _u_max_workers = ( + overwrite_concurrency or self.bulk_insert_overwrite_concurrency + ) + batch_replaced = await gather_with_concurrency( + _u_max_workers, + *[_handle_missing_document(doc) 
for doc in missing_from_batch], + ) + return batch_inserted + batch_replaced - _u_max_workers = ( - overwrite_concurrency or self.bulk_insert_overwrite_concurrency - ) - batch_replaced = await gather_with_concurrency( - _u_max_workers, - *[_handle_missing_document(doc) for doc in missing_from_batch], + _b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency + all_ids_nested = await gather_with_concurrency( + _b_max_workers, + *[ + _handle_batch(batch) + for batch in batch_iterate( + batch_size or self.batch_size, + documents_to_insert, ) - return batch_inserted + batch_replaced - - _b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency - all_ids_nested = await gather_with_concurrency( - _b_max_workers, - *[ - _handle_batch(batch) - for batch in batch_iterate( - batch_size or self.batch_size, - documents_to_insert, - ) - ], - ) + ], + ) - return [iid for id_list in all_ids_nested for iid in id_list] + return [iid for id_list in all_ids_nested for iid in id_list] def similarity_search_with_score_id_by_vector( self, @@ -749,7 +719,7 @@ def similarity_search_with_score_id_by_vector( Returns: The list of (Document, score, id), the most similar to the query vector. """ - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) # hits = list( @@ -794,15 +764,7 @@ async def asimilarity_search_with_score_id_by_vector( Returns: The list of (Document, score, id), the most similar to the query vector. """ - await self._ensure_db_setup() - if not self.async_collection: - return await run_in_executor( - None, - self.similarity_search_with_score_id_by_vector, - embedding, - k, - filter, - ) + await self.astra_env.aensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) # return [ @@ -1121,7 +1083,7 @@ def max_marginal_relevance_search_by_vector( Returns: The list of Documents selected by maximal marginal relevance. 
""" - self._ensure_astra_db_client() + self.astra_env.ensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) prefetch_hits = list( @@ -1167,18 +1129,7 @@ async def amax_marginal_relevance_search_by_vector( Returns: The list of Documents selected by maximal marginal relevance. """ - await self._ensure_db_setup() - if not self.async_collection: - return await run_in_executor( - None, - self.max_marginal_relevance_search_by_vector, - embedding, - k, - fetch_k, - lambda_mult, - filter, - **kwargs, - ) + await self.astra_env.aensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) prefetch_hits = [ diff --git a/libs/partners/astradb/tests/integration_tests/test_storage.py b/libs/partners/astradb/tests/integration_tests/test_storage.py index 86e919e4df38f..6370f4554a5c3 100644 --- a/libs/partners/astradb/tests/integration_tests/test_storage.py +++ b/libs/partners/astradb/tests/integration_tests/test_storage.py @@ -174,3 +174,29 @@ def test_bytestore_mset(self, astra_db: AstraDB) -> None: assert result["data"]["document"]["value"] == "dmFsdWUy" finally: astra_db.delete_collection(collection_name) + + def test_indexing_detection(self, astra_db: AstraDB) -> None: + """Test the behaviour against preexisting legacy collections.""" + astra_db.create_collection("lc_test_legacy_store") + astra_db.create_collection( + "lc_test_custom_store", options={"indexing": {"allow": ["my_field"]}} + ) + AstraDBStore(collection_name="lc_test_regular_store", astra_db_client=astra_db) + + # repeated instantiation must work + AstraDBStore(collection_name="lc_test_regular_store", astra_db_client=astra_db) + # on a legacy collection must just give a warning + with pytest.warns(UserWarning) as rec_warnings: + AstraDBStore( + collection_name="lc_test_legacy_store", astra_db_client=astra_db + ) + assert len(rec_warnings) == 1 + # on a custom collection must error + with pytest.raises(ValueError): + AstraDBStore( + collection_name="lc_test_custom_store", 
astra_db_client=astra_db + ) + + astra_db.delete_collection("lc_test_legacy_store") + astra_db.delete_collection("lc_test_custom_store") + astra_db.delete_collection("lc_test_regular_store") diff --git a/libs/partners/astradb/tests/integration_tests/test_vectorstores.py b/libs/partners/astradb/tests/integration_tests/test_vectorstores.py index 8e6d20152b93f..afa01992a48fd 100644 --- a/libs/partners/astradb/tests/integration_tests/test_vectorstores.py +++ b/libs/partners/astradb/tests/integration_tests/test_vectorstores.py @@ -16,9 +16,11 @@ import json import math import os +import warnings from typing import Iterable, List, Optional, TypedDict import pytest +from astrapy.db import AstraDB, AsyncAstraDB from langchain_core.documents import Document from langchain_core.embeddings import Embeddings @@ -163,7 +165,6 @@ def test_astradb_vectorstore_create_delete( self, astradb_credentials: AstraDBCredentials ) -> None: """Create and delete.""" - from astrapy.db import AstraDB as LibAstraDB emb = SomeEmbeddings(dimension=2) # creation by passing the connection secrets @@ -179,7 +180,7 @@ def test_astradb_vectorstore_create_delete( v_store.clear() # Creation by passing a ready-made astrapy client: - astra_db_client = LibAstraDB( + astra_db_client = AstraDB( **astradb_credentials, ) v_store_2 = AstraDBVectorStore( @@ -206,8 +207,6 @@ async def test_astradb_vectorstore_create_delete_async( ) await v_store.adelete_collection() # Creation by passing a ready-made astrapy client: - from astrapy.db import AsyncAstraDB - astra_db_client = AsyncAstraDB( **astradb_credentials, ) @@ -866,3 +865,188 @@ def test_astradb_vectorstore_metrics( vstore_euc.delete_collection() else: vstore_euc.clear() + + def test_astradb_vectorstore_indexing(self) -> None: + """ + Test that the right errors/warnings are issued depending + on the compatibility of on-DB indexing settings and the requested ones. + + We do NOT check for substrings in the warning messages: that would + be too brittle a test. 
+ """ + astra_db = AstraDB( + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + ) + + embe = SomeEmbeddings(dimension=2) + + # creation of three collections to test warnings against + astra_db.create_collection("lc_legacy_coll", dimension=2, metric=None) + AstraDBVectorStore( + astra_db_client=astra_db, + collection_name="lc_default_idx", + embedding=embe, + ) + AstraDBVectorStore( + astra_db_client=astra_db, + collection_name="lc_custom_idx", + embedding=embe, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + + # these invocations should just work without warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + AstraDBVectorStore( + astra_db_client=astra_db, + collection_name="lc_default_idx", + embedding=embe, + ) + AstraDBVectorStore( + astra_db_client=astra_db, + collection_name="lc_custom_idx", + embedding=embe, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + + # some are to throw an error: + with pytest.raises(ValueError): + AstraDBVectorStore( + astra_db_client=astra_db, + collection_name="lc_default_idx", + embedding=embe, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + + with pytest.raises(ValueError): + AstraDBVectorStore( + astra_db_client=astra_db, + collection_name="lc_custom_idx", + embedding=embe, + metadata_indexing_exclude={"changed_fields"}, + ) + + with pytest.raises(ValueError): + AstraDBVectorStore( + astra_db_client=astra_db, + collection_name="lc_custom_idx", + embedding=embe, + ) + + with pytest.raises(ValueError): + AstraDBVectorStore( + astra_db_client=astra_db, + collection_name="lc_legacy_coll", + embedding=embe, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + + # one case should result in just a warning: + with pytest.warns(UserWarning) as rec_warnings: + AstraDBVectorStore( + astra_db_client=astra_db, + 
collection_name="lc_legacy_coll", + embedding=embe, + ) + assert len(rec_warnings) == 1 + + # cleanup + astra_db.delete_collection("lc_legacy_coll") + astra_db.delete_collection("lc_default_idx") + astra_db.delete_collection("lc_custom_idx") + + async def test_astradb_vectorstore_indexing_async(self) -> None: + """ + Async version of the same test on warnings/errors related + to incompatible indexing choices. + """ + astra_db = AsyncAstraDB( + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + ) + + embe = SomeEmbeddings(dimension=2) + + # creation of three collections to test warnings against + await astra_db.create_collection("lc_legacy_coll", dimension=2, metric=None) + AstraDBVectorStore( + async_astra_db_client=astra_db, + collection_name="lc_default_idx", + embedding=embe, + ) + AstraDBVectorStore( + async_astra_db_client=astra_db, + collection_name="lc_custom_idx", + embedding=embe, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + + # these invocations should just work without warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + def_store = AstraDBVectorStore( + async_astra_db_client=astra_db, + collection_name="lc_default_idx", + embedding=embe, + ) + await def_store.aadd_texts(["All good."]) + cus_store = AstraDBVectorStore( + async_astra_db_client=astra_db, + collection_name="lc_custom_idx", + embedding=embe, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + await cus_store.aadd_texts(["All good."]) + + # some are to throw an error: + with pytest.raises(ValueError): + def_store = AstraDBVectorStore( + async_astra_db_client=astra_db, + collection_name="lc_default_idx", + embedding=embe, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + await def_store.aadd_texts(["Not working."]) + + with pytest.raises(ValueError): + cus_store = AstraDBVectorStore( + 
async_astra_db_client=astra_db, + collection_name="lc_custom_idx", + embedding=embe, + metadata_indexing_exclude={"changed_fields"}, + ) + await cus_store.aadd_texts(["Not working."]) + + with pytest.raises(ValueError): + cus_store = AstraDBVectorStore( + async_astra_db_client=astra_db, + collection_name="lc_custom_idx", + embedding=embe, + ) + await cus_store.aadd_texts(["Not working."]) + + with pytest.raises(ValueError): + leg_store = AstraDBVectorStore( + async_astra_db_client=astra_db, + collection_name="lc_legacy_coll", + embedding=embe, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + await leg_store.aadd_texts(["Not working."]) + + # one case should result in just a warning: + with pytest.warns(UserWarning) as rec_warnings: + leg_store = AstraDBVectorStore( + async_astra_db_client=astra_db, + collection_name="lc_legacy_coll", + embedding=embe, + ) + await leg_store.aadd_texts(["Triggering warning."]) + assert len(rec_warnings) == 1 + + await astra_db.delete_collection("lc_legacy_coll") + await astra_db.delete_collection("lc_default_idx") + await astra_db.delete_collection("lc_custom_idx") diff --git a/libs/partners/astradb/tests/unit_tests/test_vectorstores.py b/libs/partners/astradb/tests/unit_tests/test_vectorstores.py index ebfc6978d18c7..0110862203c81 100644 --- a/libs/partners/astradb/tests/unit_tests/test_vectorstores.py +++ b/libs/partners/astradb/tests/unit_tests/test_vectorstores.py @@ -1,9 +1,13 @@ from typing import List from unittest.mock import Mock +import pytest from langchain_core.embeddings import Embeddings -from langchain_astradb.vectorstores import AstraDBVectorStore +from langchain_astradb.vectorstores import ( + DEFAULT_INDEXING_OPTIONS, + AstraDBVectorStore, +) class SomeEmbeddings(Embeddings): @@ -34,12 +38,74 @@ async def aembed_query(self, text: str) -> List[float]: return self.embed_query(text) -def test_initialization() -> None: - """Test integration vectorstore initialization.""" - mock_astra_db = Mock() 
-    embedding = SomeEmbeddings(dimension=2)
-    AstraDBVectorStore(
-        embedding=embedding,
-        collection_name="mock_coll_name",
-        astra_db_client=mock_astra_db,
-    )
+class TestAstraDB:
+    def test_initialization(self) -> None:
+        """Test integration vectorstore initialization."""
+        mock_astra_db = Mock()
+        embedding = SomeEmbeddings(dimension=2)
+        AstraDBVectorStore(
+            embedding=embedding,
+            collection_name="mock_coll_name",
+            astra_db_client=mock_astra_db,
+        )
+
+    def test_astradb_vectorstore_unit_indexing_normalization(self) -> None:
+        """Unit test of the indexing policy normalization."""
+        n3_idx = AstraDBVectorStore._normalize_metadata_indexing_policy(
+            metadata_indexing_include=None,
+            metadata_indexing_exclude=None,
+            collection_indexing_policy=None,
+        )
+        assert n3_idx == DEFAULT_INDEXING_OPTIONS
+
+        al_idx = AstraDBVectorStore._normalize_metadata_indexing_policy(
+            metadata_indexing_include=["a1", "a2"],
+            metadata_indexing_exclude=None,
+            collection_indexing_policy=None,
+        )
+        assert al_idx == {"allow": ["metadata.a1", "metadata.a2"]}
+
+        dl_idx = AstraDBVectorStore._normalize_metadata_indexing_policy(
+            metadata_indexing_include=None,
+            metadata_indexing_exclude=["d1", "d2"],
+            collection_indexing_policy=None,
+        )
+        assert dl_idx == {"deny": ["metadata.d1", "metadata.d2"]}
+
+        custom_policy = {
+            "deny": ["myfield", "other_field.subfield", "metadata.long_text"]
+        }
+        cip_idx = AstraDBVectorStore._normalize_metadata_indexing_policy(
+            metadata_indexing_include=None,
+            metadata_indexing_exclude=None,
+            collection_indexing_policy=custom_policy,
+        )
+        assert cip_idx == custom_policy
+
+        with pytest.raises(ValueError):
+            AstraDBVectorStore._normalize_metadata_indexing_policy(
+                metadata_indexing_include=["a"],
+                metadata_indexing_exclude=["b"],
+                collection_indexing_policy=None,
+            )
+
+        with pytest.raises(ValueError):
+            AstraDBVectorStore._normalize_metadata_indexing_policy(
+                metadata_indexing_include=["a"],
+                metadata_indexing_exclude=None,
+                collection_indexing_policy={"a": "z"},
+            )
+
+        with pytest.raises(ValueError):
+            AstraDBVectorStore._normalize_metadata_indexing_policy(
+                metadata_indexing_include=None,
+                metadata_indexing_exclude=["b"],
+                collection_indexing_policy={"a": "z"},
+            )
+
+        with pytest.raises(ValueError):
+            AstraDBVectorStore._normalize_metadata_indexing_policy(
+                metadata_indexing_include=["a"],
+                metadata_indexing_exclude=["b"],
+                collection_indexing_policy={"a": "z"},
+            )