diff --git a/libs/astradb/langchain_astradb/__init__.py b/libs/astradb/langchain_astradb/__init__.py index 0811d51..1695924 100644 --- a/libs/astradb/langchain_astradb/__init__.py +++ b/libs/astradb/langchain_astradb/__init__.py @@ -1,3 +1,5 @@ +from astrapy.info import CollectionVectorServiceOptions + from langchain_astradb.cache import AstraDBCache, AstraDBSemanticCache from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory from langchain_astradb.document_loaders import AstraDBLoader @@ -12,4 +14,5 @@ "AstraDBChatMessageHistory", "AstraDBLoader", "AstraDBVectorStore", + "CollectionVectorServiceOptions", ] diff --git a/libs/astradb/langchain_astradb/utils/astradb.py b/libs/astradb/langchain_astradb/utils/astradb.py index a7d07b0..520c2d0 100644 --- a/libs/astradb/langchain_astradb/utils/astradb.py +++ b/libs/astradb/langchain_astradb/utils/astradb.py @@ -11,6 +11,7 @@ import langchain_core from astrapy.api import APIRequestError from astrapy.db import AstraDB, AstraDBCollection, AsyncAstraDB, AsyncAstraDBCollection +from astrapy.info import CollectionVectorServiceOptions class SetupMode(Enum): @@ -94,6 +95,9 @@ def __init__( metric: Optional[str] = None, requested_indexing_policy: Optional[Dict[str, Any]] = None, default_indexing_policy: Optional[Dict[str, Any]] = None, + collection_vector_service_options: Optional[ + CollectionVectorServiceOptions + ] = None, ) -> None: super().__init__( token, api_endpoint, astra_db_client, async_astra_db_client, namespace @@ -126,12 +130,18 @@ async def _setup_db() -> None: else: dimension = embedding_dimension + # Used for enabling $vectorize on the collection + service_dict: Optional[Dict[str, Any]] = None + if collection_vector_service_options is not None: + service_dict = collection_vector_service_options.as_dict() + try: await async_astra_db.create_collection( collection_name, dimension=dimension, metric=metric, options=_options, + service_dict=service_dict, ) except (APIRequestError, ValueError): # 
possibly the collection is preexisting and may have legacy, @@ -161,12 +171,18 @@ async def _setup_db() -> None: "set to False" ) else: + # Used for enabling $vectorize on the collection + service_dict: Optional[Dict[str, Any]] = None + if collection_vector_service_options is not None: + service_dict = collection_vector_service_options.as_dict() + try: self.astra_db.create_collection( collection_name, dimension=embedding_dimension, # type: ignore[arg-type] metric=metric, options=_options, + service_dict=service_dict, ) except (APIRequestError, ValueError): # possibly the collection is preexisting and may have legacy, diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index ffdea86..a3d16a9 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -25,6 +25,7 @@ from astrapy.db import ( AsyncAstraDB as AsyncAstraDBClient, ) +from astrapy.info import CollectionVectorServiceOptions from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.runnables.utils import gather_with_concurrency @@ -135,8 +136,8 @@ def _normalize_metadata_indexing_policy( def __init__( self, *, - embedding: Embeddings, collection_name: str, + embedding: Optional[Embeddings] = None, token: Optional[str] = None, api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDBClient] = None, @@ -152,6 +153,9 @@ def __init__( metadata_indexing_include: Optional[Iterable[str]] = None, metadata_indexing_exclude: Optional[Iterable[str]] = None, collection_indexing_policy: Optional[Dict[str, Any]] = None, + collection_vector_service_options: Optional[ + CollectionVectorServiceOptions + ] = None, ) -> None: """Wrapper around DataStax Astra DB for vector-store workloads. @@ -175,7 +179,10 @@ def __init__( results = vectorstore.similarity_search("Everything's ok", k=1) Args: - embedding: embedding function to use. 
+ embedding: the embeddings function or service to use. + This enables client-side embedding functions or calls to external + embedding providers. Only one of `embedding` or + `collection_vector_service_options` can be provided. collection_name: name of the Astra DB collection to create/use. token: API token for Astra DB usage. api_endpoint: full URL to the API endpoint, such as @@ -209,6 +216,11 @@ def __init__( This dict must conform to to the API specifications (see docs.datastax.com/en/astra/astra-db-vector/api-reference/ data-api-commands.html#advanced-feature-indexing-clause-on-createcollection) + collection_vector_service_options: specifies the use of server-side + embeddings within Astra DB. Only one of `embedding` or + `collection_vector_service_options` can be provided. + NOTE: This feature is currently under development. + Note: For concurrency in synchronous :meth:`~add_texts`:, as a rule of thumb, on a @@ -227,12 +239,27 @@ def __init__( Remember you can pass concurrency settings to individual calls to :meth:`~add_texts` and :meth:`~add_documents` as well. """ + # Embedding and collection_vector_service_options are mutually exclusive, + # as both specify how to produce embeddings + if embedding is None and collection_vector_service_options is None: + raise ValueError( + "Either an `embedding` or a `collection_vector_service_options`\ + must be provided." + ) + + if embedding is not None and collection_vector_service_options is not None: + raise ValueError( + "Only one of `embedding` or `collection_vector_service_options`\ + can be provided." 
+ ) + self.embedding_dimension: Optional[int] = None self.embedding = embedding self.collection_name = collection_name self.token = token self.api_endpoint = api_endpoint self.namespace = namespace + self.collection_vector_service_options = collection_vector_service_options # Concurrency settings self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE self.bulk_insert_batch_concurrency: int = ( @@ -248,10 +275,11 @@ def __init__( # "vector-related" settings self.metric = metric embedding_dimension: Union[int, Awaitable[int], None] = None - if setup_mode == SetupMode.ASYNC: - embedding_dimension = self._aget_embedding_dimension() - elif setup_mode == SetupMode.SYNC or setup_mode == SetupMode.OFF: - embedding_dimension = self._get_embedding_dimension() + if self.embedding is not None: + if setup_mode == SetupMode.ASYNC: + embedding_dimension = self._aget_embedding_dimension() + elif setup_mode == SetupMode.SYNC or setup_mode == SetupMode.OFF: + embedding_dimension = self._get_embedding_dimension() # indexing policy setting self.indexing_policy: Dict[str, Any] = self._normalize_metadata_indexing_policy( @@ -273,6 +301,7 @@ def __init__( metric=metric, requested_indexing_policy=self.indexing_policy, default_indexing_policy=DEFAULT_INDEXING_OPTIONS, + collection_vector_service_options=collection_vector_service_options, ) self.astra_db = self.astra_env.astra_db self.async_astra_db = self.astra_env.async_astra_db @@ -280,6 +309,8 @@ def __init__( self.async_collection = self.astra_env.async_collection def _get_embedding_dimension(self) -> int: + assert self.embedding is not None + if self.embedding_dimension is None: self.embedding_dimension = len( self.embedding.embed_query(text="This is a sample sentence.") @@ -287,6 +318,8 @@ def _get_embedding_dimension(self) -> int: return self.embedding_dimension async def _aget_embedding_dimension(self) -> int: + assert self.embedding is not None + if self.embedding_dimension is None: self.embedding_dimension = len( await 
self.embedding.aembed_query(text="This is a sample sentence.") @@ -294,9 +327,17 @@ async def _aget_embedding_dimension(self) -> int: return self.embedding_dimension @property - def embeddings(self) -> Embeddings: + def embeddings(self) -> Optional[Embeddings]: + """ + Accesses the supplied embeddings object. If using server-side embeddings, + this will return None. + """ return self.embedding + def _using_vectorize(self) -> bool: + """Indicates whether server-side embeddings are being used.""" + return self.collection_vector_service_options is not None + def _select_relevance_score_fn(self) -> Callable[[float], float]: """ The underlying API calls already returns a "score proper", @@ -474,6 +515,36 @@ def _get_documents_to_insert( )[::-1] return uniqued_documents_to_insert + @staticmethod + def _get_vectorize_documents_to_insert( + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + ) -> List[DocDict]: + if ids is None: + ids = [uuid.uuid4().hex for _ in texts] + if metadatas is None: + metadatas = [{} for _ in texts] + # + documents_to_insert = [ + { + "_id": b_id, + "$vectorize": b_txt, + "metadata": b_md, + } + for b_txt, b_id, b_md in zip( + texts, + ids, + metadatas, + ) + ] + # make unique by id, keeping the last + uniqued_documents_to_insert = _unique_list( + documents_to_insert[::-1], + lambda document: document["_id"], + )[::-1] + return uniqued_documents_to_insert + @staticmethod def _get_missing_from_batch( document_batch: List[DocDict], insert_result: Dict[str, Any] @@ -555,10 +626,16 @@ def add_texts( ) self.astra_env.ensure_db_setup() - embedding_vectors = self.embedding.embed_documents(list(texts)) - documents_to_insert = self._get_documents_to_insert( - texts, embedding_vectors, metadatas, ids - ) + if self._using_vectorize(): + documents_to_insert = self._get_vectorize_documents_to_insert( + texts, metadatas, ids + ) + else: + assert self.embedding is not None + embedding_vectors = 
self.embedding.embed_documents(list(texts)) + documents_to_insert = self._get_documents_to_insert( + texts, embedding_vectors, metadatas, ids + ) def _handle_batch(document_batch: List[DocDict]) -> List[str]: # self.collection is not None (by _ensure_astra_db_client) @@ -651,10 +728,17 @@ async def aadd_texts( ) await self.astra_env.aensure_db_setup() - embedding_vectors = await self.embedding.aembed_documents(list(texts)) - documents_to_insert = self._get_documents_to_insert( - texts, embedding_vectors, metadatas, ids - ) + if self._using_vectorize(): + # using server-side embeddings + documents_to_insert = self._get_vectorize_documents_to_insert( + texts, metadatas, ids + ) + else: + assert self.embedding is not None + embedding_vectors = await self.embedding.aembed_documents(list(texts)) + documents_to_insert = self._get_documents_to_insert( + texts, embedding_vectors, metadatas, ids + ) async def _handle_batch(document_batch: List[DocDict]) -> List[str]: # self.async_collection is not None here for sure @@ -783,6 +867,81 @@ async def asimilarity_search_with_score_id_by_vector( ) ] + def _similarity_search_with_score_id_with_vectorize( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + ) -> List[Tuple[Document, float, str]]: + """Return docs most similar to the query with score and id using $vectorize. + + This is only available when using server-side embeddings. 
+ """ + self.astra_env.ensure_db_setup() + metadata_parameter = self._filter_to_metadata(filter) + # + hits = list( + # self.collection is not None (by _ensure_astra_db_client) + self.collection.paginated_find( # type: ignore[union-attr] + filter=metadata_parameter, + sort={"$vectorize": query}, + options={"limit": k, "includeSimilarity": True}, + projection={ + "_id": 1, + "$vectorize": 1, + "metadata": 1, + }, + ) + ) + # + return [ + ( + Document( + # text content is stored in $vectorize instead of content + page_content=hit["$vectorize"], + metadata=hit["metadata"], + ), + hit["$similarity"], + hit["_id"], + ) + for hit in hits + ] + + async def _asimilarity_search_with_score_id_with_vectorize( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + ) -> List[Tuple[Document, float, str]]: + """Return docs most similar to the query with score and id using $vectorize. + + This is only available when using server-side embeddings. + """ + await self.astra_env.aensure_db_setup() + metadata_parameter = self._filter_to_metadata(filter) + # + return [ + ( + Document( + # text content is stored in $vectorize instead of content + page_content=hit["$vectorize"], + metadata=hit["metadata"], + ), + hit["$similarity"], + hit["_id"], + ) + async for hit in self.async_collection.paginated_find( + filter=metadata_parameter, + sort={"$vectorize": query}, + options={"limit": k, "includeSimilarity": True}, + projection={ + "_id": 1, + "$vectorize": 1, + "metadata": 1, + }, + ) + ] + def similarity_search_with_score_id( self, query: str, @@ -799,12 +958,21 @@ def similarity_search_with_score_id( Returns: The list of (Document, score, id), the most similar to the query. 
""" - embedding_vector = self.embedding.embed_query(query) - return self.similarity_search_with_score_id_by_vector( - embedding=embedding_vector, - k=k, - filter=filter, - ) + + if self._using_vectorize(): + return self._similarity_search_with_score_id_with_vectorize( + query=query, + k=k, + filter=filter, + ) + else: + assert self.embedding is not None + embedding_vector = self.embedding.embed_query(query) + return self.similarity_search_with_score_id_by_vector( + embedding=embedding_vector, + k=k, + filter=filter, + ) async def asimilarity_search_with_score_id( self, @@ -822,12 +990,20 @@ async def asimilarity_search_with_score_id( Returns: The list of (Document, score, id), the most similar to the query. """ - embedding_vector = await self.embedding.aembed_query(query) - return await self.asimilarity_search_with_score_id_by_vector( - embedding=embedding_vector, - k=k, - filter=filter, - ) + if self._using_vectorize(): + return await self._asimilarity_search_with_score_id_with_vectorize( + query=query, + k=k, + filter=filter, + ) + else: + assert self.embedding is not None + embedding_vector = await self.embedding.aembed_query(query) + return await self.asimilarity_search_with_score_id_by_vector( + embedding=embedding_vector, + k=k, + filter=filter, + ) def similarity_search_with_score_by_vector( self, @@ -900,12 +1076,23 @@ def similarity_search( Returns: The list of Documents most similar to the query. 
""" - embedding_vector = self.embedding.embed_query(query) - return self.similarity_search_by_vector( - embedding_vector, - k, - filter=filter, - ) + if self._using_vectorize(): + return [ + doc + for (doc, _, _) in self._similarity_search_with_score_id_with_vectorize( + query, + k, + filter=filter, + ) + ] + else: + assert self.embedding is not None + embedding_vector = self.embedding.embed_query(query) + return self.similarity_search_by_vector( + embedding_vector, + k, + filter=filter, + ) async def asimilarity_search( self, @@ -924,12 +1111,27 @@ async def asimilarity_search( Returns: The list of Documents most similar to the query. """ - embedding_vector = await self.embedding.aembed_query(query) - return await self.asimilarity_search_by_vector( - embedding_vector, - k, - filter=filter, - ) + if self._using_vectorize(): + return [ + doc + for ( + doc, + _, + _, + ) in await self._asimilarity_search_with_score_id_with_vectorize( + query, + k, + filter=filter, + ) + ] + else: + assert self.embedding is not None + embedding_vector = await self.embedding.aembed_query(query) + return await self.asimilarity_search_by_vector( + embedding_vector, + k, + filter=filter, + ) def similarity_search_by_vector( self, @@ -999,12 +1201,27 @@ def similarity_search_with_score( Returns: The list of (Document, score), the most similar to the query vector. 
""" - embedding_vector = self.embedding.embed_query(query) - return self.similarity_search_with_score_by_vector( - embedding_vector, - k, - filter=filter, - ) + if self._using_vectorize(): + return [ + (doc, score) + for ( + doc, + score, + doc_id, + ) in self._similarity_search_with_score_id_with_vectorize( + query=query, + k=k, + filter=filter, + ) + ] + else: + assert self.embedding is not None + embedding_vector = self.embedding.embed_query(query) + return self.similarity_search_with_score_by_vector( + embedding_vector, + k, + filter=filter, + ) async def asimilarity_search_with_score( self, @@ -1022,12 +1239,27 @@ async def asimilarity_search_with_score( Returns: The list of (Document, score), the most similar to the query vector. """ - embedding_vector = await self.embedding.aembed_query(query) - return await self.asimilarity_search_with_score_by_vector( - embedding_vector, - k, - filter=filter, - ) + if self._using_vectorize(): + return [ + (doc, score) + for ( + doc, + score, + doc_id, + ) in await self._asimilarity_search_with_score_id_with_vectorize( + query=query, + k=k, + filter=filter, + ) + ] + else: + assert self.embedding is not None + embedding_vector = await self.embedding.aembed_query(query) + return await self.asimilarity_search_with_score_by_vector( + embedding_vector, + k, + filter=filter, + ) @staticmethod def _get_mmr_hits( @@ -1170,14 +1402,18 @@ def max_marginal_relevance_search( Returns: The list of Documents selected by maximal marginal relevance. 
""" - embedding_vector = self.embedding.embed_query(query) - return self.max_marginal_relevance_search_by_vector( - embedding_vector, - k, - fetch_k, - lambda_mult=lambda_mult, - filter=filter, - ) + if self._using_vectorize(): + raise ValueError("MMR search is unsupported for server-side embeddings.") + else: + assert self.embedding is not None + embedding_vector = self.embedding.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding_vector, + k, + fetch_k, + lambda_mult=lambda_mult, + filter=filter, + ) async def amax_marginal_relevance_search( self, @@ -1205,19 +1441,26 @@ async def amax_marginal_relevance_search( Returns: The list of Documents selected by maximal marginal relevance. """ - embedding_vector = await self.embedding.aembed_query(query) - return await self.amax_marginal_relevance_search_by_vector( - embedding_vector, - k, - fetch_k, - lambda_mult=lambda_mult, - filter=filter, - ) + if self._using_vectorize(): + raise ValueError("MMR search is unsupported for server-side embeddings.") + else: + assert self.embedding is not None + embedding_vector = await self.embedding.aembed_query(query) + return await self.amax_marginal_relevance_search_by_vector( + embedding_vector, + k, + fetch_k, + lambda_mult=lambda_mult, + filter=filter, + ) @classmethod def _from_kwargs( cls: Type[AstraDBVectorStore], - embedding: Embeddings, + embedding: Optional[Embeddings] = None, + collection_vector_service_options: Optional[ + CollectionVectorServiceOptions + ] = None, **kwargs: Any, ) -> AstraDBVectorStore: known_kwargs = { @@ -1268,15 +1511,19 @@ def _from_kwargs( "bulk_insert_overwrite_concurrency" ), bulk_delete_concurrency=kwargs.get("bulk_delete_concurrency"), + collection_vector_service_options=collection_vector_service_options, ) @classmethod def from_texts( cls: Type[AstraDBVectorStore], texts: List[str], - embedding: Embeddings, + embedding: Optional[Embeddings] = None, metadatas: Optional[List[dict]] = None, ids: 
Optional[List[str]] = None, + collection_vector_service_options: Optional[ + CollectionVectorServiceOptions + ] = None, **kwargs: Any, ) -> AstraDBVectorStore: """Create an Astra DB vectorstore from raw texts. @@ -1284,8 +1531,14 @@ def from_texts( Args: texts: the texts to insert. embedding: the embedding function to use in the store. + This enables client-side embedding functions or calls to external + embedding providers. Only one of `embedding` or + `collection_vector_service_options` can be provided. metadatas: metadata dicts for the texts. ids: ids to associate to the texts. + collection_vector_service_options: specifies the use of server-side + embeddings within Astra DB. Only one of `embedding` or + `collection_vector_service_options` can be provided. **kwargs: you can pass any argument that you would to :meth:`~add_texts` and/or to the 'AstraDBVectorStore' constructor (see these methods for details). These arguments will be @@ -1294,7 +1547,11 @@ def from_texts( Returns: an `AstraDBVectorStore` vectorstore. """ - astra_db_store = AstraDBVectorStore._from_kwargs(embedding, **kwargs) + astra_db_store = AstraDBVectorStore._from_kwargs( + embedding=embedding, + collection_vector_service_options=collection_vector_service_options, + **kwargs, + ) astra_db_store.add_texts( texts=texts, metadatas=metadatas, @@ -1309,9 +1566,12 @@ def from_texts( async def afrom_texts( cls: Type[AstraDBVectorStore], texts: List[str], - embedding: Embeddings, + embedding: Optional[Embeddings] = None, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, + collection_vector_service_options: Optional[ + CollectionVectorServiceOptions + ] = None, **kwargs: Any, ) -> AstraDBVectorStore: """Create an Astra DB vectorstore from raw texts. @@ -1321,6 +1581,9 @@ async def afrom_texts( embedding: the embedding function to use in the store. metadatas: metadata dicts for the texts. ids: ids to associate to the texts. 
+ collection_vector_service_options: specifies the use of server-side + embeddings within Astra DB. Only one of `embedding` or + `collection_vector_service_options` can be provided. **kwargs: you can pass any argument that you would to :meth:`~add_texts` and/or to the 'AstraDBVectorStore' constructor (see these methods for details). These arguments will be @@ -1329,7 +1592,11 @@ async def afrom_texts( Returns: an `AstraDBVectorStore` vectorstore. """ - astra_db_store = AstraDBVectorStore._from_kwargs(embedding, **kwargs) + astra_db_store = AstraDBVectorStore._from_kwargs( + embedding, + collection_vector_service_options=collection_vector_service_options, + **kwargs, + ) await astra_db_store.aadd_texts( texts=texts, metadatas=metadatas, @@ -1344,7 +1611,10 @@ async def afrom_texts( def from_documents( cls: Type[AstraDBVectorStore], documents: List[Document], - embedding: Embeddings, + embedding: Optional[Embeddings] = None, + collection_vector_service_options: Optional[ + CollectionVectorServiceOptions + ] = None, **kwargs: Any, ) -> AstraDBVectorStore: """Create an Astra DB vectorstore from a document list. @@ -1357,4 +1627,42 @@ def from_documents( Returns: an `AstraDBVectorStore` vectorstore. """ - return super().from_documents(documents, embedding, **kwargs) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + return cls.from_texts( + texts, + embedding=embedding, + metadatas=metadatas, + collection_vector_service_options=collection_vector_service_options, + **kwargs, + ) + + @classmethod + async def afrom_documents( + cls: Type[AstraDBVectorStore], + documents: List[Document], + embedding: Optional[Embeddings] = None, + collection_vector_service_options: Optional[ + CollectionVectorServiceOptions + ] = None, + **kwargs: Any, + ) -> AstraDBVectorStore: + """Create an Astra DB vectorstore from a document list. + + Utility method that defers to 'afrom_texts' (see that one). 
+ + Args: see 'afrom_texts', except here you have to supply 'documents' + in place of 'texts' and 'metadatas'. + + Returns: + an `AstraDBVectorStore` vectorstore. + """ + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + return await cls.afrom_texts( + texts, + embedding=embedding, + metadatas=metadatas, + collection_vector_service_options=collection_vector_service_options, + **kwargs, + ) diff --git a/libs/astradb/tests/integration_tests/test_vectorstores.py b/libs/astradb/tests/integration_tests/test_vectorstores.py index 68e4aa9..7faddf8 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstores.py +++ b/libs/astradb/tests/integration_tests/test_vectorstores.py @@ -21,6 +21,7 @@ import pytest from astrapy.db import AstraDB, AsyncAstraDB +from astrapy.info import CollectionVectorServiceOptions from langchain_core.documents import Document from langchain_core.embeddings import Embeddings @@ -33,9 +34,20 @@ COLLECTION_NAME_DIM2 = "lc_test_d2" COLLECTION_NAME_DIM2_EUCLIDEAN = "lc_test_d2_eucl" +COLLECTION_NAME_VECTORIZE = "lc_test_vectorize" MATCH_EPSILON = 0.0001 + +def is_vector_service_available() -> bool: + return all( + [ + "us-west-2" in os.environ.get("ASTRA_DB_API_ENDPOINT", ""), + "astra-dev.datastax.com" in os.environ.get("ASTRA_DB_API_ENDPOINT", ""), + ] + ) + + # Ad-hoc embedding classes: @@ -159,6 +171,33 @@ def store_parseremb( v_store.clear() +@pytest.fixture(scope="function") +def vectorize_store( + astradb_credentials: AstraDBCredentials, +) -> Iterable[AstraDBVectorStore]: + """ + astra db vector store with server-side embeddings using the nvidia model + """ + # Only available in dev us-west-2 now + if not is_vector_service_available(): + pytest.skip("vectorize unavailable") + + options = CollectionVectorServiceOptions( + provider="nvidia", model_name="NV-Embed-QA" + ) + v_store = AstraDBVectorStore( + collection_vector_service_options=options, + collection_name=COLLECTION_NAME_VECTORIZE, + 
**astradb_credentials, + ) + v_store.clear() + + yield v_store + + # explicitly delete the collection to avoid max collection limit + v_store.delete_collection() + + @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") class TestAstraDBVectorStore: def test_astradb_vectorstore_create_delete( @@ -167,7 +206,8 @@ def test_astradb_vectorstore_create_delete( """Create and delete.""" emb = SomeEmbeddings(dimension=2) - # creation by passing the connection secrets + + # Creation by passing the connection secrets v_store = AstraDBVectorStore( embedding=emb, collection_name=COLLECTION_NAME_DIM2, @@ -194,18 +234,37 @@ def test_astradb_vectorstore_create_delete( else: v_store_2.clear() + @pytest.mark.skipif( + not is_vector_service_available(), reason="vectorize unavailable" + ) + def test_astradb_vectorstore_create_delete_vectorize( + self, astradb_credentials: AstraDBCredentials + ) -> None: + """Create and delete with vectorize option.""" + options = CollectionVectorServiceOptions( + provider="nvidia", model_name="NV-Embed-QA" + ) + v_store = AstraDBVectorStore( + collection_vector_service_options=options, + collection_name=COLLECTION_NAME_VECTORIZE, + **astradb_credentials, + ) + v_store.add_texts(["Sample 1"]) + v_store.delete_collection() + async def test_astradb_vectorstore_create_delete_async( self, astradb_credentials: AstraDBCredentials ) -> None: """Create and delete.""" emb = SomeEmbeddings(dimension=2) - # creation by passing the connection secrets + # Creation by passing the connection secrets v_store = AstraDBVectorStore( embedding=emb, collection_name=COLLECTION_NAME_DIM2, **astradb_credentials, ) await v_store.adelete_collection() + # Creation by passing a ready-made astrapy client: astra_db_client = AsyncAstraDB( **astradb_credentials, @@ -220,6 +279,23 @@ async def test_astradb_vectorstore_create_delete_async( else: await v_store_2.aclear() + @pytest.mark.skipif( + not is_vector_service_available(), reason="vectorize unavailable" + 
) + async def test_astradb_vectorstore_create_delete_vectorize_async( + self, astradb_credentials: AstraDBCredentials + ) -> None: + """Create and delete with vectorize option.""" + options = CollectionVectorServiceOptions( + provider="nvidia", model_name="NV-Embed-QA" + ) + v_store = AstraDBVectorStore( + collection_vector_service_options=options, + collection_name=COLLECTION_NAME_VECTORIZE, + **astradb_credentials, + ) + await v_store.adelete_collection() + @pytest.mark.skipif( SKIP_COLLECTION_DELETE, reason="Collection-deletion tests are suppressed", @@ -337,6 +413,50 @@ def test_astradb_vectorstore_from_x( else: v_store_2.clear() + @pytest.mark.skipif( + not is_vector_service_available(), reason="vectorize unavailable" + ) + def test_astradb_vectorstore_from_x_vectorize( + self, astradb_credentials: AstraDBCredentials + ) -> None: + """from_texts and from_documents methods with vectorize.""" + options = CollectionVectorServiceOptions( + provider="nvidia", model_name="NV-Embed-QA" + ) + + AstraDBVectorStore( + collection_vector_service_options=options, + collection_name=COLLECTION_NAME_VECTORIZE, + **astradb_credentials, + ).clear() + + # from_texts + v_store = AstraDBVectorStore.from_texts( + texts=["Hi", "Ho"], + collection_vector_service_options=options, + collection_name=COLLECTION_NAME_VECTORIZE, + **astradb_credentials, + ) + try: + assert v_store.similarity_search("Ho", k=1)[0].page_content == "Ho" + finally: + v_store.delete_collection() + + # from_documents + v_store_2 = AstraDBVectorStore.from_documents( + [ + Document(page_content="Hee"), + Document(page_content="Hoi"), + ], + collection_vector_service_options=options, + collection_name=COLLECTION_NAME_VECTORIZE, + **astradb_credentials, + ) + try: + assert v_store_2.similarity_search("Hoi", k=1)[0].page_content == "Hoi" + finally: + v_store_2.delete_collection() + async def test_astradb_vectorstore_from_x_async( self, astradb_credentials: AstraDBCredentials ) -> None: @@ -383,12 +503,58 @@ async def 
test_astradb_vectorstore_from_x_async( else: await v_store_2.aclear() - def test_astradb_vectorstore_crud(self, store_someemb: AstraDBVectorStore) -> None: + @pytest.mark.skipif( + not is_vector_service_available(), reason="vectorize unavailable" + ) + async def test_astradb_vectorstore_from_x_async_vectorize( + self, astradb_credentials: AstraDBCredentials + ) -> None: + """from_texts and from_documents methods with vectorize.""" + # from_text with vectorize + options = CollectionVectorServiceOptions( + provider="nvidia", model_name="NV-Embed-QA" + ) + v_store = await AstraDBVectorStore.afrom_texts( + texts=["Haa", "Huu"], + collection_vector_service_options=options, + collection_name=COLLECTION_NAME_VECTORIZE, + **astradb_credentials, + ) + try: + assert (await v_store.asimilarity_search("Haa", k=1))[ + 0 + ].page_content == "Haa" + finally: + await v_store.adelete_collection() + + # from_documents with vectorize + v_store_2 = await AstraDBVectorStore.afrom_documents( + [ + Document(page_content="HeeH"), + Document(page_content="HooH"), + ], + collection_vector_service_options=options, + collection_name=COLLECTION_NAME_VECTORIZE, + **astradb_credentials, + ) + try: + assert (await v_store_2.asimilarity_search("HeeH", k=1))[ + 0 + ].page_content == "HeeH" + finally: + await v_store_2.adelete_collection() + + @pytest.mark.parametrize("vector_store", ["store_someemb", "vectorize_store"]) + def test_astradb_vectorstore_crud( + self, vector_store: str, request: pytest.FixtureRequest + ) -> None: """Basic add/delete/update behaviour.""" - res0 = store_someemb.similarity_search("Abc", k=2) + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + + res0 = vstore.similarity_search("Abc", k=2) assert res0 == [] # write and check again - store_someemb.add_texts( + vstore.add_texts( texts=["aa", "bb", "cc"], metadatas=[ {"k": "a", "ord": 0}, @@ -397,10 +563,10 @@ def test_astradb_vectorstore_crud(self, store_someemb: AstraDBVectorStore) -> No ], ids=["a", "b", 
"c"], ) - res1 = store_someemb.similarity_search("Abc", k=5) + res1 = vstore.similarity_search("Abc", k=5) assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"} # partial overwrite and count total entries - store_someemb.add_texts( + vstore.add_texts( texts=["cc", "dd"], metadatas=[ {"k": "c_new", "ord": 102}, @@ -408,46 +574,48 @@ def test_astradb_vectorstore_crud(self, store_someemb: AstraDBVectorStore) -> No ], ids=["c", "d"], ) - res2 = store_someemb.similarity_search("Abc", k=10) + res2 = vstore.similarity_search("Abc", k=10) assert len(res2) == 4 # pick one that was just updated and check its metadata - res3 = store_someemb.similarity_search_with_score_id( + res3 = vstore.similarity_search_with_score_id( query="cc", k=1, filter={"k": "c_new"} ) - doc3, score3, id3 = res3[0] + doc3, _, id3 = res3[0] assert doc3.page_content == "cc" assert doc3.metadata == {"k": "c_new", "ord": 102} - assert score3 > 0.999 # leaving some leeway for approximations... assert id3 == "c" # delete and count again - del1_res = store_someemb.delete(["b"]) + del1_res = vstore.delete(["b"]) assert del1_res is True - del2_res = store_someemb.delete(["a", "c", "Z!"]) + del2_res = vstore.delete(["a", "c", "Z!"]) assert del2_res is True # a non-existing ID was supplied - assert len(store_someemb.similarity_search("xy", k=10)) == 1 + assert len(vstore.similarity_search("xy", k=10)) == 1 # clear store - store_someemb.clear() - assert store_someemb.similarity_search("Abc", k=2) == [] + vstore.clear() + assert vstore.similarity_search("Abc", k=2) == [] # add_documents with "ids" arg passthrough - store_someemb.add_documents( + vstore.add_documents( [ Document(page_content="vv", metadata={"k": "v", "ord": 204}), Document(page_content="ww", metadata={"k": "w", "ord": 205}), ], ids=["v", "w"], ) - assert len(store_someemb.similarity_search("xy", k=10)) == 2 - res4 = store_someemb.similarity_search("ww", k=1, filter={"k": "w"}) + assert len(vstore.similarity_search("xy", k=10)) == 2 + 
res4 = vstore.similarity_search("ww", k=1, filter={"k": "w"}) assert res4[0].metadata["ord"] == 205 + @pytest.mark.parametrize("vector_store", ["store_someemb", "vectorize_store"]) async def test_astradb_vectorstore_crud_async( - self, store_someemb: AstraDBVectorStore + self, vector_store: str, request: pytest.FixtureRequest ) -> None: """Basic add/delete/update behaviour.""" - res0 = await store_someemb.asimilarity_search("Abc", k=2) + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + + res0 = await vstore.asimilarity_search("Abc", k=2) assert res0 == [] # write and check again - await store_someemb.aadd_texts( + await vstore.aadd_texts( texts=["aa", "bb", "cc"], metadatas=[ {"k": "a", "ord": 0}, @@ -456,10 +624,10 @@ async def test_astradb_vectorstore_crud_async( ], ids=["a", "b", "c"], ) - res1 = await store_someemb.asimilarity_search("Abc", k=5) + res1 = await vstore.asimilarity_search("Abc", k=5) assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"} # partial overwrite and count total entries - await store_someemb.aadd_texts( + await vstore.aadd_texts( texts=["cc", "dd"], metadatas=[ {"k": "c_new", "ord": 102}, @@ -467,36 +635,35 @@ async def test_astradb_vectorstore_crud_async( ], ids=["c", "d"], ) - res2 = await store_someemb.asimilarity_search("Abc", k=10) + res2 = await vstore.asimilarity_search("Abc", k=10) assert len(res2) == 4 # pick one that was just updated and check its metadata - res3 = await store_someemb.asimilarity_search_with_score_id( + res3 = await vstore.asimilarity_search_with_score_id( query="cc", k=1, filter={"k": "c_new"} ) - doc3, score3, id3 = res3[0] + doc3, _, id3 = res3[0] assert doc3.page_content == "cc" assert doc3.metadata == {"k": "c_new", "ord": 102} - assert score3 > 0.999 # leaving some leeway for approximations... 
assert id3 == "c" # delete and count again - del1_res = await store_someemb.adelete(["b"]) + del1_res = await vstore.adelete(["b"]) assert del1_res is True - del2_res = await store_someemb.adelete(["a", "c", "Z!"]) + del2_res = await vstore.adelete(["a", "c", "Z!"]) assert del2_res is False # a non-existing ID was supplied - assert len(await store_someemb.asimilarity_search("xy", k=10)) == 1 + assert len(await vstore.asimilarity_search("xy", k=10)) == 1 # clear store - await store_someemb.aclear() - assert await store_someemb.asimilarity_search("Abc", k=2) == [] + await vstore.aclear() + assert await vstore.asimilarity_search("Abc", k=2) == [] # add_documents with "ids" arg passthrough - await store_someemb.aadd_documents( + await vstore.aadd_documents( [ Document(page_content="vv", metadata={"k": "v", "ord": 204}), Document(page_content="ww", metadata={"k": "w", "ord": 205}), ], ids=["v", "w"], ) - assert len(await store_someemb.asimilarity_search("xy", k=10)) == 2 - res4 = await store_someemb.asimilarity_search("ww", k=1, filter={"k": "w"}) + assert len(await vstore.asimilarity_search("xy", k=10)) == 2 + res4 = await vstore.asimilarity_search("ww", k=1, filter={"k": "w"}) assert res4[0].metadata["ord"] == 205 def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDBVectorStore) -> None: @@ -552,11 +719,31 @@ def _v_from_i(i: int, N: int) -> str: res_i_vals = {doc.metadata["i"] for doc in res1} assert res_i_vals == {0, 4} + def test_astradb_vectorstore_mmr_vectorize_unsupported( + self, vectorize_store: AstraDBVectorStore + ) -> None: + """ + MMR testing with vectorize, currently unsupported. + """ + with pytest.raises(ValueError): + vectorize_store.max_marginal_relevance_search("aa", k=2, fetch_k=3) + + async def test_astradb_vectorstore_mmr_vectorize_unsupported_async( + self, vectorize_store: AstraDBVectorStore + ) -> None: + """ + MMR async testing with vectorize, currently unsupported. 
+ """ + with pytest.raises(ValueError): + await vectorize_store.amax_marginal_relevance_search("aa", k=2, fetch_k=3) + + @pytest.mark.parametrize("vector_store", ["store_someemb", "vectorize_store"]) def test_astradb_vectorstore_metadata( - self, store_someemb: AstraDBVectorStore + self, vector_store: str, request: pytest.FixtureRequest ) -> None: """Metadata filtering.""" - store_someemb.add_documents( + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.add_documents( [ Document( page_content="q", @@ -585,49 +772,51 @@ def test_astradb_vectorstore_metadata( ] ) # no filters - res0 = store_someemb.similarity_search("x", k=10) + res0 = vstore.similarity_search("x", k=10) assert {doc.page_content for doc in res0} == set("qwreio") # single filter - res1 = store_someemb.similarity_search( + res1 = vstore.similarity_search( "x", k=10, filter={"group": "vowel"}, ) assert {doc.page_content for doc in res1} == set("eio") # multiple filters - res2 = store_someemb.similarity_search( + res2 = vstore.similarity_search( "x", k=10, filter={"group": "consonant", "ord": ord("q")}, ) assert {doc.page_content for doc in res2} == set("q") # excessive filters - res3 = store_someemb.similarity_search( + res3 = vstore.similarity_search( "x", k=10, filter={"group": "consonant", "ord": ord("q"), "case": "upper"}, ) assert res3 == [] # filter with logical operator - res4 = store_someemb.similarity_search( + res4 = vstore.similarity_search( "x", k=10, filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]}, ) assert {doc.page_content for doc in res4} == {"q", "r"} + @pytest.mark.parametrize("vector_store", ["store_parseremb"]) def test_astradb_vectorstore_similarity_scale( - self, store_parseremb: AstraDBVectorStore + self, vector_store: str, request: pytest.FixtureRequest ) -> None: """Scale of the similarity scores.""" - store_parseremb.add_texts( + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.add_texts( texts=[ json.dumps([1, 1]), 
json.dumps([-1, -1]), ], ids=["near", "far"], ) - res1 = store_parseremb.similarity_search_with_score( + res1 = vstore.similarity_search_with_score( json.dumps([0.5, 0.5]), k=2, ) @@ -635,18 +824,20 @@ def test_astradb_vectorstore_similarity_scale( sco_near, sco_far = scores assert abs(1 - sco_near) < MATCH_EPSILON and abs(sco_far) < MATCH_EPSILON + @pytest.mark.parametrize("vector_store", ["store_parseremb"]) async def test_astradb_vectorstore_similarity_scale_async( - self, store_parseremb: AstraDBVectorStore + self, vector_store: str, request: pytest.FixtureRequest ) -> None: """Scale of the similarity scores.""" - await store_parseremb.aadd_texts( + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + await vstore.aadd_texts( texts=[ json.dumps([1, 1]), json.dumps([-1, -1]), ], ids=["near", "far"], ) - res1 = await store_parseremb.asimilarity_search_with_score( + res1 = await vstore.asimilarity_search_with_score( json.dumps([0.5, 0.5]), k=2, ) @@ -654,24 +845,26 @@ async def test_astradb_vectorstore_similarity_scale_async( sco_near, sco_far = scores assert abs(1 - sco_near) < MATCH_EPSILON and abs(sco_far) < MATCH_EPSILON + @pytest.mark.parametrize("vector_store", ["store_someemb", "vectorize_store"]) def test_astradb_vectorstore_massive_delete( - self, store_someemb: AstraDBVectorStore + self, vector_store: str, request: pytest.FixtureRequest ) -> None: """Larger-scale bulk deletes.""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) M = 50 texts = [str(i + 1 / 7.0) for i in range(2 * M)] ids0 = ["doc_%i" % i for i in range(M)] ids1 = ["doc_%i" % (i + M) for i in range(M)] ids = ids0 + ids1 - store_someemb.add_texts(texts=texts, ids=ids) + vstore.add_texts(texts=texts, ids=ids) # deleting a bunch of these - del_res0 = store_someemb.delete(ids0) + del_res0 = vstore.delete(ids0) assert del_res0 is True # deleting the rest plus a fake one - del_res1 = store_someemb.delete(ids1 + ["ghost!"]) + del_res1 = vstore.delete(ids1 + 
["ghost!"]) assert del_res1 is True # ensure no error # nothing left - assert store_someemb.similarity_search("x", k=2 * M) == [] + assert vstore.similarity_search("x", k=2 * M) == [] @pytest.mark.skipif( SKIP_COLLECTION_DELETE, diff --git a/libs/astradb/tests/unit_tests/test_imports.py b/libs/astradb/tests/unit_tests/test_imports.py index c2bcff9..dbffdd5 100644 --- a/libs/astradb/tests/unit_tests/test_imports.py +++ b/libs/astradb/tests/unit_tests/test_imports.py @@ -8,6 +8,7 @@ "AstraDBChatMessageHistory", "AstraDBLoader", "AstraDBVectorStore", + "CollectionVectorServiceOptions", ] diff --git a/libs/astradb/tests/unit_tests/test_vectorstores.py b/libs/astradb/tests/unit_tests/test_vectorstores.py index 0110862..7c24ef6 100644 --- a/libs/astradb/tests/unit_tests/test_vectorstores.py +++ b/libs/astradb/tests/unit_tests/test_vectorstores.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from astrapy.info import CollectionVectorServiceOptions from langchain_core.embeddings import Embeddings from langchain_astradb.vectorstores import ( @@ -49,6 +50,30 @@ def test_initialization(self) -> None: astra_db_client=mock_astra_db, ) + # Test with server-side embeddings + vector_options = CollectionVectorServiceOptions( + provider="test", model_name="test" + ) + AstraDBVectorStore( + collection_name="mock_coll_name", + astra_db_client=mock_astra_db, + collection_vector_service_options=vector_options, + ) + + with pytest.raises(ValueError): + AstraDBVectorStore( + embedding=embedding, + collection_name="mock_coll_name", + astra_db_client=mock_astra_db, + collection_vector_service_options=vector_options, + ) + + with pytest.raises(ValueError): + AstraDBVectorStore( + collection_name="mock_coll_name", + astra_db_client=mock_astra_db, + ) + def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: """Unit test of the indexing policy normalization""" n3_idx = AstraDBVectorStore._normalize_metadata_indexing_policy(