diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index 0556d18..f8200c9 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -39,13 +39,14 @@ logger = logging.getLogger(__name__) -class AdjacentNode: +class EmbeddedNode: id: str links: list[Link] embedding: list[float] - def __init__(self, node: Node, embedding: list[float]) -> None: - """Create an Adjacent Node.""" + def __init__(self, doc: Document, embedding: list[float]) -> None: + """Create an Embedded Node.""" + node = _doc_to_node(doc=doc) self.id = node.id or "" self.links = node.links self.embedding = embedding @@ -90,11 +91,11 @@ def _doc_to_node(doc: Document) -> Node: ) -def _incoming_links(node: Node | AdjacentNode) -> set[Link]: +def _incoming_links(node: Node | EmbeddedNode) -> set[Link]: return {link for link in node.links if link.direction in ["in", "bidir"]} -def _outgoing_links(node: Node | AdjacentNode) -> set[Link]: +def _outgoing_links(node: Node | EmbeddedNode) -> set[Link]: return {link for link in node.links if link.direction in ["out", "bidir"]} @@ -104,7 +105,7 @@ def __init__( self, *, collection_name: str, - embedding: Embeddings, + embedding: Embeddings | None = None, metadata_incoming_links_key: str = "incoming_links", token: str | TokenProvider | None = None, api_endpoint: str | None = None, @@ -262,7 +263,6 @@ def __init__( :meth:`~add_texts` and :meth:`~add_documents` as well. """ self.metadata_incoming_links_key = metadata_incoming_links_key - self.embedding = embedding # update indexing policy to ensure incoming_links are indexed if metadata_indexing_include is not None: @@ -362,7 +362,7 @@ def __init__( @property @override def embeddings(self) -> Embeddings | None: - return self.embedding + return self.vector_store.embedding def _get_metadata_filter( self, @@ -454,13 +454,20 @@ async def aadd_nodes( def from_texts( cls: type[AstraDBGraphVectorStore], texts: Iterable[str], - embedding: Embeddings, + embedding: Embeddings | None = None, metadatas: list[dict] | None = None, ids: Iterable[str] | None = None, + collection_vector_service_options: CollectionVectorServiceOptions | None = None, + collection_embedding_api_key: str | EmbeddingHeadersProvider | None = None, **kwargs: Any, ) -> AstraDBGraphVectorStore: """Return AstraDBGraphVectorStore initialized from texts and embeddings.""" - store = cls(embedding=embedding, **kwargs) + store = cls( + embedding=embedding, + collection_vector_service_options=collection_vector_service_options, + collection_embedding_api_key=collection_embedding_api_key, + **kwargs, + ) store.add_texts(texts, metadatas, ids=ids) return store @@ -469,12 +476,19 @@ def from_texts( def from_documents( cls: type[AstraDBGraphVectorStore], documents: Iterable[Document], - embedding: Embeddings, + embedding: Embeddings | None = None, ids: Iterable[str] | None = None, + collection_vector_service_options: CollectionVectorServiceOptions | None = None, + collection_embedding_api_key: str | EmbeddingHeadersProvider | None = None, **kwargs: Any, ) -> AstraDBGraphVectorStore: """Return AstraDBGraphVectorStore initialized from docs and embeddings.""" - store = cls(embedding=embedding, **kwargs) + store = cls( + embedding=embedding, + collection_vector_service_options=collection_vector_service_options, + collection_embedding_api_key=collection_embedding_api_key, + **kwargs, + ) store.add_documents(documents, ids=ids) return store @@ -717,21 +731,43 @@ async def ammr_traversal_search( # noqa: C901 filter: Optional metadata to filter the results. **kwargs: Additional keyword arguments. """ - query_embedding = self.embedding.embed_query(query) - helper = MmrHelper( - k=k, - query_embedding=query_embedding, - lambda_mult=lambda_mult, - score_threshold=score_threshold, - ) - # For each unselected node, stores the outgoing links. outgoing_links_map: dict[str, set[Link]] = {} visited_links: set[Link] = set() - # Map from id to Document + # Map from id to Document, used as a cache retrieved_docs: dict[str, Document] = {} - async def fetch_neighborhood(neighborhood: Sequence[str]) -> None: + def get_candidates(nodes: Iterable[EmbeddedNode]) -> dict[str, list[float]]: + nonlocal outgoing_links_map + + candidates: dict[str, list[float]] = {} + for node in nodes: + if node.id not in outgoing_links_map: + outgoing_links_map[node.id] = _outgoing_links(node=node) + candidates[node.id] = node.embedding + return candidates + + async def fetch_initial_candidates() -> ( + tuple[list[float], dict[str, list[float]]] + ): + """Gets the embedded query and the set of initial candidates. + + If fetch_k is zero, there will be no initial candidates. + """ + nonlocal retrieved_docs + + query_embedding, initial_nodes = await self._get_initial( + query=query, + retrieved_docs=retrieved_docs, + fetch_k=fetch_k, + filter=filter, + ) + + return query_embedding, get_candidates(nodes=initial_nodes) + + async def fetch_neighborhood_candidates( + neighborhood: Sequence[str], + ) -> dict[str, list[float]]: nonlocal outgoing_links_map, visited_links, retrieved_docs # Put the neighborhood into the outgoing links, to avoid adding it @@ -753,41 +789,20 @@ async def fetch_neighborhood(neighborhood: Sequence[str]) -> None: retrieved_docs=retrieved_docs, ) - new_candidates: dict[str, list[float]] = {} - for adjacent_node in adjacent_nodes: - if adjacent_node.id not in outgoing_links_map: - outgoing_links_map[adjacent_node.id] = _outgoing_links( - node=adjacent_node - ) - new_candidates[adjacent_node.id] = adjacent_node.embedding - helper.add_candidates(new_candidates) - - async def fetch_initial_candidates() -> None: - nonlocal outgoing_links_map, visited_links, retrieved_docs - - results = ( - await self.vector_store.asimilarity_search_with_embedding_id_by_vector( - embedding=query_embedding, - k=fetch_k, - filter=filter, - ) - ) - - candidates: dict[str, list[float]] = {} - for doc, embedding, doc_id in results: - if doc_id not in retrieved_docs: - retrieved_docs[doc_id] = doc + return get_candidates(nodes=adjacent_nodes) - if doc_id not in outgoing_links_map: - node = _doc_to_node(doc) - outgoing_links_map[doc_id] = _outgoing_links(node=node) - candidates[doc_id] = embedding - helper.add_candidates(candidates) + query_embedding, initial_candidates = await fetch_initial_candidates() + helper = MmrHelper( + k=k, + query_embedding=query_embedding, + lambda_mult=lambda_mult, + score_threshold=score_threshold, + ) + helper.add_candidates(candidates=initial_candidates) if initial_roots: - await fetch_neighborhood(initial_roots) - if fetch_k > 0: - await fetch_initial_candidates() + neighborhood_candidates = await fetch_neighborhood_candidates(initial_roots) + helper.add_candidates(candidates=neighborhood_candidates) # Tracks the depth of each candidate. depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()} @@ -1142,6 +1157,30 @@ async def _get_outgoing_links(self, source_ids: Iterable[str]) -> set[Link]: return links + async def _get_initial( + self, + query: str, + retrieved_docs: dict[str, Document], + fetch_k: int, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> tuple[list[float], list[EmbeddedNode]]: + ( + query_embedding, + result, + ) = await self.vector_store.asimilarity_search_with_embedding( + query=query, + k=fetch_k, + filter=filter, + ) + + initial_nodes: list[EmbeddedNode] = [] + for doc, embedding in result: + if doc.id is not None: + retrieved_docs[doc.id] = doc + initial_nodes.append(EmbeddedNode(doc=doc, embedding=embedding)) + + return query_embedding, initial_nodes + async def _get_adjacent( self, links: set[Link], @@ -1149,7 +1188,7 @@ async def _get_adjacent( retrieved_docs: dict[str, Document], k_per_link: int | None = None, filter: dict[str, Any] | None = None, # noqa: A002 - ) -> Iterable[AdjacentNode]: + ) -> Iterable[EmbeddedNode]: """Return the target nodes with incoming links from any of the given links. Args: @@ -1162,7 +1201,7 @@ async def _get_adjacent( Returns: Iterable of adjacent edges. """ - targets: dict[str, AdjacentNode] = {} + targets: dict[str, EmbeddedNode] = {} tasks = [] for link in links: @@ -1172,22 +1211,21 @@ async def _get_adjacent( ) tasks.append( - self.vector_store.asimilarity_search_with_embedding_id_by_vector( + self.vector_store.asimilarity_search_with_embedding_by_vector( embedding=query_embedding, k=k_per_link or 10, filter=metadata_filter, ) ) - results = await asyncio.gather(*tasks) + results: list[list[tuple[Document, list[float]]]] = await asyncio.gather(*tasks) for result in results: - for doc, embedding, doc_id in result: - if doc_id not in retrieved_docs: - retrieved_docs[doc_id] = doc - if doc_id not in targets: - node = _doc_to_node(doc=doc) - targets[doc_id] = AdjacentNode(node=node, embedding=embedding) + for doc, embedding in result: + if doc.id is not None: + retrieved_docs[doc.id] = doc + if doc.id not in targets: + targets[doc.id] = EmbeddedNode(doc=doc, embedding=embedding) # TODO: Consider a combined limit based on the similarity and/or # predicated MMR score? diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index 645aa24..bfb03da 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -1787,13 +1787,13 @@ async def asimilarity_search_with_score_id_by_vector( filter=filter, ) - async def asimilarity_search_with_embedding_id_by_vector( + def similarity_search_with_embedding_by_vector( self, embedding: list[float], k: int = 4, filter: dict[str, Any] | None = None, # noqa: A002 - ) -> list[tuple[Document, list[float], str]]: - """Return docs most similar to embedding vector. + ) -> list[tuple[Document, list[float]]]: + """Return docs most similar to embedding vector with embedding. Args: embedding: Embedding to look up documents similar to. @@ -1801,29 +1801,181 @@ async def asimilarity_search_with_embedding_id_by_vector( filter: Filter on the metadata to apply. Returns: - List of (Document, embedding, id), the most similar to the query vector. + (The query embedding vector, The list of (Document, embedding), + the most similar to the query vector.). + """ + sort = self.document_codec.encode_vector_sort(vector=embedding) + _, doc_emb_list = self._similarity_search_with_embedding_by_sort( + sort=sort, k=k, filter=filter + ) + return doc_emb_list + + async def asimilarity_search_with_embedding_by_vector( + self, + embedding: list[float], + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> list[tuple[Document, list[float]]]: + """Return docs most similar to embedding vector with embedding. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + + Returns: + (The query embedding vector, The list of (Document, embedding), + the most similar to the query vector.). + """ + sort = self.document_codec.encode_vector_sort(vector=embedding) + _, doc_emb_list = await self._asimilarity_search_with_embedding_by_sort( + sort=sort, k=k, filter=filter + ) + return doc_emb_list + + def similarity_search_with_embedding( + self, + query: str, + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> tuple[list[float], list[tuple[Document, list[float]]]]: + """Return docs most similar to the query with embedding. + + Also includes the query embedding vector. + + Args: + query: Query to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + + Returns: + (The query embedding vector, The list of (Document, embedding), + the most similar to the query vector.). + """ + if self.document_codec.server_side_embeddings: + sort = {"$vectorize": query} + else: + query_embedding = self._get_safe_embedding().embed_query(text=query) + # shortcut return if query isn't needed. + if k == 0: + return (query_embedding, []) + sort = self.document_codec.encode_vector_sort(vector=query_embedding) + + return self._similarity_search_with_embedding_by_sort( + sort=sort, k=k, filter=filter + ) + + async def asimilarity_search_with_embedding( + self, + query: str, + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> tuple[list[float], list[tuple[Document, list[float]]]]: + """Return docs most similar to the query with embedding. + + Also includes the query embedding vector. + + Args: + query: Query to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + + Returns: + (The query embedding vector, The list of (Document, embedding), + the most similar to the query vector.). + """ + if self.document_codec.server_side_embeddings: + sort = {"$vectorize": query} + else: + query_embedding = self._get_safe_embedding().embed_query(text=query) + # shortcut return if query isn't needed. + if k == 0: + return (query_embedding, []) + sort = self.document_codec.encode_vector_sort(vector=query_embedding) + + return await self._asimilarity_search_with_embedding_by_sort( + sort=sort, k=k, filter=filter + ) + + async def _asimilarity_search_with_embedding_by_sort( + self, + sort: dict[str, Any], + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> tuple[list[float], list[tuple[Document, list[float]]]]: + """Run ANN search with a provided sort clause. + + Returns: + (query_embedding, List of (Document, embedding) most similar to the query). """ await self.astra_env.aensure_db_setup() - metadata_parameter = self.filter_to_query(filter).copy() - results: list[tuple[Document, list[float], str]] = [] - async for hit in self.astra_env.async_collection.find( - filter=metadata_parameter, + async_cursor = self.astra_env.async_collection.find( + filter=self.filter_to_query(filter), projection=self.document_codec.full_projection, limit=k, - include_similarity=True, include_sort_vector=True, - sort=self.document_codec.encode_vector_sort(embedding), - ): - doc = self.document_codec.decode(hit) - if doc is None or doc.id is None: - continue + sort=sort, + ) + sort_vector = await async_cursor.get_sort_vector() + if sort_vector is None: + msg = "Unable to retrieve the server-side embedding of the query." + raise ValueError(msg) + query_embedding = sort_vector - vector = self.document_codec.decode_vector(hit) - if vector is None: - continue + return ( + query_embedding, + [ + (doc, emb) + async for (doc, emb) in ( + ( + self.document_codec.decode(hit), + self.document_codec.decode_vector(hit), + ) + async for hit in async_cursor + ) + if doc is not None and emb is not None + ], + ) + + def _similarity_search_with_embedding_by_sort( + self, + sort: dict[str, Any], + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> tuple[list[float], list[tuple[Document, list[float]]]]: + """Run ANN search with a provided sort clause. + + Returns: + (query_embedding, List of (Document, embedding) most similar to the query). + """ + self.astra_env.ensure_db_setup() + cursor = self.astra_env.collection.find( + filter=self.filter_to_query(filter), + projection=self.document_codec.full_projection, + limit=k, + include_sort_vector=True, + sort=sort, + ) + sort_vector = cursor.get_sort_vector() + if sort_vector is None: + msg = "Unable to retrieve the server-side embedding of the query." + raise ValueError(msg) + query_embedding = sort_vector - results.append((doc, vector, doc.id)) - return results + return ( + query_embedding, + [ + (doc, emb) + for (doc, emb) in ( + ( + self.document_codec.decode(hit), + self.document_codec.decode_vector(hit), + ) + for hit in cursor + ) + if doc is not None and emb is not None + ], + ) async def _asimilarity_search_with_score_id_by_sort( self, diff --git a/libs/astradb/tests/integration_tests/conftest.py b/libs/astradb/tests/integration_tests/conftest.py index 0ce30d9..c65763e 100644 --- a/libs/astradb/tests/integration_tests/conftest.py +++ b/libs/astradb/tests/integration_tests/conftest.py @@ -398,7 +398,7 @@ def collection_vz( """A general-purpose $vectorize collection for per-test reuse.""" collection = database.create_collection( COLLECTION_NAME_VZ, - dimension=16, + dimension=1536, check_exists=False, indexing=DEFAULT_INDEXING_OPTIONS, metric="euclidean", @@ -464,7 +464,7 @@ def collection_idxall_vz( """ collection = database.create_collection( COLLECTION_NAME_IDXALL_VZ, - dimension=16, + dimension=1536, check_exists=False, metric="euclidean", service=OPENAI_VECTORIZE_OPTIONS_HEADER, diff --git a/libs/astradb/tests/integration_tests/test_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_graphvectorstore.py index 7b90f2b..67eb076 100644 --- a/libs/astradb/tests/integration_tests/test_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_graphvectorstore.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest from astrapy.authentication import StaticTokenProvider @@ -19,6 +19,7 @@ from .conftest import ( CUSTOM_CONTENT_KEY, LONG_TEXT, + OPENAI_VECTORIZE_OPTIONS_HEADER, astra_db_env_vars_available, ) @@ -54,9 +55,9 @@ def graph_vector_store_docs() -> list[Document]: F0 the query point is meant to be at (*). - the A are bidirectionally with B - the A are outgoing to T - the A are incoming from F + the A nodes are linked bidirectionally with B + the A nodes are linked outgoing to T + the A nodes are linked incoming from F The links are like: L with L, 0 with 0 and R with R. """ @@ -94,19 +95,102 @@ def graph_vector_store_docs() -> list[Document]: @pytest.fixture -def graph_vector_store_d2( +def graph_vector_store_docs_vz() -> list[Document]: + """ + This is a set of Documents to pre-populate a graph vector store, + with entries placed in a certain way. + + the A nodes are linked bidirectionally with B + the A nodes are linked outgoing to T + the A nodes are linked incoming from F + The links are like: L with L, 0 with 0 and R with R. + """ + + docs_a = [ # docs related to space and the universe + Document( + id="AL", page_content="planets orbit quietly", metadata={"label": "AL"} + ), + Document(id="A0", page_content="distant stars shine", metadata={"label": "A0"}), + Document( + id="AR", page_content="nebulae swirl in space", metadata={"label": "AR"} + ), + ] + docs_b = [ # docs related to emotions and relationships + Document(id="BL", page_content="hearts intertwined", metadata={"label": "BL"}), + Document(id="B0", page_content="a gentle embrace", metadata={"label": "B0"}), + Document(id="BL", page_content="love conquers all", metadata={"label": "BR"}), + ] + docs_f = [ # docs related to technology and programming + Document( + id="FL", page_content="code compiles efficiently", metadata={"label": "FL"} + ), + Document( + id="F0", page_content="a neural network learns", metadata={"label": "F0"} + ), + Document( + id="FR", page_content="data structures organize", metadata={"label": "FR"} + ), + ] + docs_t = [ # docs related to nature and wildlife + Document( + id="TL", page_content="trees sway in the wind", metadata={"label": "TL"} + ), + Document(id="T0", page_content="a river runs deep", metadata={"label": "T0"}), + Document( + id="TR", page_content="birds chirping at dawn", metadata={"label": "TR"} + ), + ] + for doc_a, suffix in zip(docs_a, ["l", "0", "r"]): + add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) + add_links(doc_a, Link.outgoing(kind="at_example", tag=f"tag_{suffix}")) + add_links(doc_a, Link.incoming(kind="af_example", tag=f"tag_{suffix}")) + for doc_b, suffix in zip(docs_b, ["l", "0", "r"]): + add_links(doc_b, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) + for doc_t, suffix in zip(docs_t, ["l", "0", "r"]): + add_links(doc_t, Link.incoming(kind="at_example", tag=f"tag_{suffix}")) + for doc_f, suffix in zip(docs_f, ["l", "0", "r"]): + add_links(doc_f, Link.outgoing(kind="af_example", tag=f"tag_{suffix}")) + return docs_a + docs_b + docs_f + docs_t + + +@pytest.fixture +def auth_kwargs( astra_db_credentials: AstraDBCredentials, +) -> dict[str, Any]: + return { + "token": StaticTokenProvider(astra_db_credentials["token"]), + "api_endpoint": astra_db_credentials["api_endpoint"], + "namespace": astra_db_credentials["namespace"], + "environment": astra_db_credentials["environment"], + } + + +@pytest.fixture +def graph_vector_store_d2( + auth_kwargs: dict[str, Any], empty_collection_d2: Collection, embedding_d2: Embeddings, ) -> AstraDBGraphVectorStore: return AstraDBGraphVectorStore( embedding=embedding_d2, collection_name=empty_collection_d2.name, - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], setup_mode=SetupMode.OFF, + **auth_kwargs, + ) + + +@pytest.fixture +def graph_vector_store_vz( + auth_kwargs: dict[str, Any], + openai_api_key: str, + empty_collection_vz: Collection, +) -> AstraDBGraphVectorStore: + return AstraDBGraphVectorStore( + collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, + collection_embedding_api_key=openai_api_key, + collection_name=empty_collection_vz.name, + setup_mode=SetupMode.OFF, + **auth_kwargs, ) @@ -119,9 +203,18 @@ def populated_graph_vector_store_d2( return graph_vector_store_d2 +@pytest.fixture +def populated_graph_vector_store_vz( + graph_vector_store_vz: AstraDBGraphVectorStore, + graph_vector_store_docs_vz: list[Document], +) -> AstraDBGraphVectorStore: + graph_vector_store_vz.add_documents(graph_vector_store_docs_vz) + return graph_vector_store_vz + + @pytest.fixture def autodetect_populated_graph_vector_store_d2( - astra_db_credentials: AstraDBCredentials, + auth_kwargs: dict[str, Any], database: Database, embedding_d2: Embeddings, graph_vector_store_docs: list[Document], @@ -160,22 +253,64 @@ def autodetect_populated_graph_vector_store_d2( }, ] ) - gstore = AstraDBGraphVectorStore( + g_store = AstraDBGraphVectorStore( embedding=embedding_d2, collection_name=ephemeral_collection_cleaner_idxall_d2, metadata_incoming_links_key="x_link_to_x", - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], content_field="*", autodetect_collection=True, + **auth_kwargs, ) - gstore.add_documents(graph_vector_store_docs) - return gstore + g_store.add_documents(graph_vector_store_docs) + return g_store -def assert_all_flat_docs(collection: Collection) -> None: +@pytest.fixture +def autodetect_populated_graph_vector_store_vz( + auth_kwargs: dict[str, Any], + openai_api_key: str, + graph_vector_store_docs_vz: list[Document], + empty_collection_idxall_vz: Collection, +) -> AstraDBGraphVectorStore: + """ + Pre-populate the collection and have (VectorStore)autodetect work on it, + then create and return a GraphVectorStore, additionally filled with + the same (graph-)entries as for `populated_graph_vector_store_vz`. + """ + empty_collection_idxall_vz.insert_many( + [ + { + "_id": "1", + "$vectorize": "Cont1", + "mds": "S", + "mdi": 100, + }, + { + "_id": "2", + "$vectorize": "Cont2", + "mds": "T", + "mdi": 101, + }, + { + "_id": "3", + "$vectorize": "Cont3", + "mds": "U", + "mdi": 102, + }, + ] + ) + g_store = AstraDBGraphVectorStore( + collection_embedding_api_key=openai_api_key, + collection_name=empty_collection_idxall_vz.name, + metadata_incoming_links_key="x_link_to_x", + autodetect_collection=True, + **auth_kwargs, + ) + g_store.add_documents(graph_vector_store_docs_vz) + return g_store + + +def assert_all_flat_docs(collection: Collection, is_vectorize: bool) -> None: # noqa: FBT001 """ Check that all docs in the store obey the underlying (flat) autodetected doc schema on DB. @@ -183,7 +318,8 @@ def assert_all_flat_docs(collection: Collection) -> None: """ for doc in collection.find({}, projection={"*": True}): assert all(not isinstance(v, dict) for v in doc.values()) - assert CUSTOM_CONTENT_KEY in doc + content_key = "$vectorize" if is_vectorize else CUSTOM_CONTENT_KEY + assert content_key in doc assert isinstance(doc["$vector"], list) @@ -192,180 +328,285 @@ def assert_all_flat_docs(collection: Collection) -> None: ) class TestAstraDBGraphVectorStore: @pytest.mark.parametrize( - ("store_name", "is_autodetected"), + ("store_name", "is_autodetected", "is_vectorize"), [ - ("populated_graph_vector_store_d2", False), - ("autodetect_populated_graph_vector_store_d2", True), + ("populated_graph_vector_store_d2", False, False), + ("autodetect_populated_graph_vector_store_d2", True, False), + ("populated_graph_vector_store_vz", False, True), + ("autodetect_populated_graph_vector_store_vz", True, True), + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) def test_gvs_similarity_search_sync( self, *, store_name: str, is_autodetected: bool, + is_vectorize: bool, request: pytest.FixtureRequest, ) -> None: """Simple (non-graph) similarity search on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - ss_response = store.similarity_search(query="[2, 10]", k=2) + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + query = "universe" if is_vectorize else "[2, 10]" + embedding = [2.0, 10.0] + + ss_response = g_store.similarity_search(query=query, k=2) ss_labels = [doc.metadata["label"] for doc in ss_response] assert ss_labels == ["AR", "A0"] - ss_by_v_response = store.similarity_search_by_vector(embedding=[2, 10], k=2) - ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response] - assert ss_by_v_labels == ["AR", "A0"] + + if is_vectorize: + with pytest.raises( + ValueError, match=r"Searching by vector .* embeddings is not allowed" + ): + g_store.similarity_search_by_vector(embedding=embedding, k=2) + else: + ss_by_v_response = g_store.similarity_search_by_vector( + embedding=embedding, k=2 + ) + ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response] + assert ss_by_v_labels == ["AR", "A0"] + if is_autodetected: - assert_all_flat_docs(store.vector_store.astra_env.collection) + assert_all_flat_docs( + g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize + ) @pytest.mark.parametrize( - ("store_name", "is_autodetected"), + ("store_name", "is_autodetected", "is_vectorize"), [ - ("populated_graph_vector_store_d2", False), - ("autodetect_populated_graph_vector_store_d2", True), + ("populated_graph_vector_store_d2", False, False), + ("autodetect_populated_graph_vector_store_d2", True, False), + ("populated_graph_vector_store_vz", False, True), + ("autodetect_populated_graph_vector_store_vz", True, True), + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) async def test_gvs_similarity_search_async( self, *, store_name: str, is_autodetected: bool, + is_vectorize: bool, request: pytest.FixtureRequest, ) -> None: """Simple (non-graph) similarity search on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - ss_response = await store.asimilarity_search(query="[2, 10]", k=2) + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + query = "universe" if is_vectorize else "[2, 10]" + embedding = [2.0, 10.0] + + ss_response = await g_store.asimilarity_search(query=query, k=2) ss_labels = [doc.metadata["label"] for doc in ss_response] assert ss_labels == ["AR", "A0"] - ss_by_v_response = await store.asimilarity_search_by_vector( - embedding=[2, 10], k=2 - ) - ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response] - assert ss_by_v_labels == ["AR", "A0"] + + if is_vectorize: + with pytest.raises( + ValueError, match=r"Searching by vector .* embeddings is not allowed" + ): + await g_store.asimilarity_search_by_vector(embedding=embedding, k=2) + else: + ss_by_v_response = await g_store.asimilarity_search_by_vector( + embedding=embedding, k=2 + ) + ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response] + assert ss_by_v_labels == ["AR", "A0"] + if is_autodetected: - assert_all_flat_docs(store.vector_store.astra_env.collection) + assert_all_flat_docs( + g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize + ) @pytest.mark.parametrize( - ("store_name", "is_autodetected"), + ("store_name", "is_autodetected", "is_vectorize"), [ - ("populated_graph_vector_store_d2", False), - ("autodetect_populated_graph_vector_store_d2", True), + ("populated_graph_vector_store_d2", False, False), + ("autodetect_populated_graph_vector_store_d2", True, False), + ("populated_graph_vector_store_vz", False, True), + ("autodetect_populated_graph_vector_store_vz", True, True), + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) def test_gvs_traversal_search_sync( self, *, store_name: str, is_autodetected: bool, + is_vectorize: bool, request: pytest.FixtureRequest, ) -> None: """Graph traversal search on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - ts_response = store.traversal_search(query="[2, 10]", k=2, depth=2) + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + query = "universe" if is_vectorize else "[2, 10]" + # this is a set, as some of the internals of trav.search are set-driven # so ordering is not deterministic: - ts_labels = {doc.metadata["label"] for doc in ts_response} + ts_labels = { + doc.metadata["label"] + for doc in g_store.traversal_search(query=query, k=2, depth=2) + } assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} if is_autodetected: - assert_all_flat_docs(store.vector_store.astra_env.collection) + assert_all_flat_docs( + g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize + ) @pytest.mark.parametrize( - ("store_name", "is_autodetected"), + ("store_name", "is_autodetected", "is_vectorize"), [ - ("populated_graph_vector_store_d2", False), - ("autodetect_populated_graph_vector_store_d2", True), + ("populated_graph_vector_store_d2", False, False), + ("autodetect_populated_graph_vector_store_d2", True, False), + ("populated_graph_vector_store_vz", False, True), + ("autodetect_populated_graph_vector_store_vz", True, True), + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) async def test_gvs_traversal_search_async( self, *, store_name: str, is_autodetected: bool, + is_vectorize: bool, request: pytest.FixtureRequest, ) -> None: """Graph traversal search on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - ts_labels = set() - async for doc in store.atraversal_search(query="[2, 10]", k=2, depth=2): - ts_labels.add(doc.metadata["label"]) + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + query = "universe" if is_vectorize else "[2, 10]" + # this is a set, as some of the internals of trav.search are set-driven # so ordering is not deterministic: + ts_labels = { + doc.metadata["label"] + async for doc in g_store.atraversal_search(query=query, k=2, depth=2) + } assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} if is_autodetected: - assert_all_flat_docs(store.vector_store.astra_env.collection) + assert_all_flat_docs( + g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize + ) @pytest.mark.parametrize( - ("store_name", "is_autodetected"), + ("store_name", "is_autodetected", "is_vectorize"), [ - ("populated_graph_vector_store_d2", False), - ("autodetect_populated_graph_vector_store_d2", True), + ("populated_graph_vector_store_d2", False, False), + ("autodetect_populated_graph_vector_store_d2", True, False), + ("populated_graph_vector_store_vz", False, True), + ("autodetect_populated_graph_vector_store_vz", True, True), + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) def test_gvs_mmr_traversal_search_sync( self, *, store_name: str, is_autodetected: bool, + is_vectorize: bool, request: pytest.FixtureRequest, ) -> None: """MMR Graph traversal search on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - mt_response = store.mmr_traversal_search( - query="[2, 10]", - k=2, - depth=2, - fetch_k=1, - adjacent_k=2, - lambda_mult=0.1, - ) - # TODO: can this rightfully be a list (or must it be a set)? - mt_labels = {doc.metadata["label"] for doc in mt_response} - assert mt_labels == {"AR", "BR"} + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + query = "universe" if is_vectorize else "[2, 10]" + + mt_labels = [ + doc.metadata["label"] + for doc in g_store.mmr_traversal_search( + query=query, + k=2, + depth=2, + fetch_k=1, + adjacent_k=2, + lambda_mult=0.1, + ) + ] + + assert mt_labels == ["AR", "BR"] if is_autodetected: - assert_all_flat_docs(store.vector_store.astra_env.collection) + assert_all_flat_docs( + g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize + ) @pytest.mark.parametrize( - ("store_name", "is_autodetected"), + ("store_name", "is_autodetected", "is_vectorize"), [ - ("populated_graph_vector_store_d2", False), - ("autodetect_populated_graph_vector_store_d2", True), + ("populated_graph_vector_store_d2", False, False), + ("autodetect_populated_graph_vector_store_d2", True, False), + ("populated_graph_vector_store_vz", False, True), + ("autodetect_populated_graph_vector_store_vz", True, True), + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) async def test_gvs_mmr_traversal_search_async( self, *, store_name: str, is_autodetected: bool, + is_vectorize: bool, request: pytest.FixtureRequest, ) -> None: """MMR Graph traversal search on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - mt_labels = set() - async for doc in store.ammr_traversal_search( - query="[2, 10]", - k=2, - depth=2, - fetch_k=1, - adjacent_k=2, - lambda_mult=0.1, - ): - mt_labels.add(doc.metadata["label"]) - # TODO: can this rightfully be a list (or must it be a set)? - assert mt_labels == {"AR", "BR"} + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + query = "universe" if is_vectorize else "[2, 10]" + + mt_labels = [ + doc.metadata["label"] + async for doc in g_store.ammr_traversal_search( + query=query, + k=2, + depth=2, + fetch_k=1, + adjacent_k=2, + lambda_mult=0.1, + ) + ] + + assert mt_labels == ["AR", "BR"] if is_autodetected: - assert_all_flat_docs(store.vector_store.astra_env.collection) + assert_all_flat_docs( + g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize + ) @pytest.mark.parametrize( - ("store_name"), + "store_name", [ - ("populated_graph_vector_store_d2"), - ("autodetect_populated_graph_vector_store_d2"), + "populated_graph_vector_store_d2", + "autodetect_populated_graph_vector_store_d2", + "populated_graph_vector_store_vz", + "autodetect_populated_graph_vector_store_vz", + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) def test_gvs_metadata_search_sync( self, @@ -374,13 +615,13 @@ def test_gvs_metadata_search_sync( request: pytest.FixtureRequest, ) -> None: """Metadata search on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - mt_response = store.metadata_search( + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + mt_response = g_store.metadata_search( filter={"label": "T0"}, n=2, ) doc: Document = next(iter(mt_response)) - assert doc.page_content == "[-10, 0]" + assert doc.id == "T0" links = doc.metadata["links"] assert len(links) == 1 link: Link = links.pop() @@ -390,12 +631,19 @@ def test_gvs_metadata_search_sync( assert link.tag == "tag_0" @pytest.mark.parametrize( - ("store_name"), + "store_name", [ - ("populated_graph_vector_store_d2"), - ("autodetect_populated_graph_vector_store_d2"), + "populated_graph_vector_store_d2", + "autodetect_populated_graph_vector_store_d2", + "populated_graph_vector_store_vz", + "autodetect_populated_graph_vector_store_vz", + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) async def test_gvs_metadata_search_async( self, @@ -404,13 +652,13 @@ async def test_gvs_metadata_search_async( request: pytest.FixtureRequest, ) -> None: """Metadata search on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - mt_response = await store.ametadata_search( + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + mt_response = await g_store.ametadata_search( filter={"label": "T0"}, n=2, ) doc: Document = next(iter(mt_response)) - assert doc.page_content == "[-10, 0]" + assert doc.id == "T0" links: set[Link] = doc.metadata["links"] assert len(links) == 1 link: Link = links.pop() @@ -420,12 +668,19 @@ async def test_gvs_metadata_search_async( assert link.tag == "tag_0" @pytest.mark.parametrize( - ("store_name"), + "store_name", [ - ("populated_graph_vector_store_d2"), - ("autodetect_populated_graph_vector_store_d2"), + "populated_graph_vector_store_d2", + "autodetect_populated_graph_vector_store_d2", + "populated_graph_vector_store_vz", + "autodetect_populated_graph_vector_store_vz", + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) def test_gvs_get_by_document_id_sync( self, @@ -434,10 +689,10 @@ def test_gvs_get_by_document_id_sync( request: pytest.FixtureRequest, ) -> None: """Get by document_id on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - doc = store.get_by_document_id(document_id="FL") + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + doc = g_store.get_by_document_id(document_id="FL") assert doc is not None - assert doc.page_content == "[1, -9]" + assert doc.metadata["label"] == "FL" links = doc.metadata["links"] assert len(links) == 1 link: Link = links.pop() @@ -446,16 +701,23 @@ def test_gvs_get_by_document_id_sync( assert link.kind == "af_example" assert link.tag == "tag_l" - invalid_doc = store.get_by_document_id(document_id="invalid") + invalid_doc = g_store.get_by_document_id(document_id="invalid") assert invalid_doc is None @pytest.mark.parametrize( - ("store_name"), + "store_name", [ - ("populated_graph_vector_store_d2"), - ("autodetect_populated_graph_vector_store_d2"), + "populated_graph_vector_store_d2", + "autodetect_populated_graph_vector_store_d2", + "populated_graph_vector_store_vz", + "autodetect_populated_graph_vector_store_vz", + ], + ids=[ + "native_store_d2", + "autodetected_store_d2", + "native_store_vz", + "autodetected_store_vz", ], - ids=["native_store", "autodetected_store"], ) async def test_gvs_get_by_document_id_async( self, @@ -464,10 +726,10 @@ async def test_gvs_get_by_document_id_async( request: pytest.FixtureRequest, ) -> None: """Get by document_id on a graph vector store.""" - store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) - doc = await store.aget_by_document_id(document_id="FL") + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + doc = await g_store.aget_by_document_id(document_id="FL") assert doc is not None - assert doc.page_content == "[1, -9]" + assert doc.metadata["label"] == "FL" links = doc.metadata["links"] assert len(links) == 1 link: Link = links.pop() @@ -476,71 +738,127 @@ async def test_gvs_get_by_document_id_async( assert link.kind == "af_example" assert link.tag == "tag_l" - invalid_doc = await store.aget_by_document_id(document_id="invalid") + invalid_doc = await g_store.aget_by_document_id(document_id="invalid") assert invalid_doc is None + @pytest.mark.parametrize( + ("is_vectorize", "page_contents", "collection_fixture_name"), + [ + (False, ["[1, 2]"], "empty_collection_d2"), + (True, ["varenyky, holubtsi, and deruny"], "empty_collection_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) def test_gvs_from_texts( self, *, - astra_db_credentials: AstraDBCredentials, - empty_collection_d2: Collection, + auth_kwargs: dict[str, Any], + openai_api_key: str, embedding_d2: Embeddings, + is_vectorize: bool, + page_contents: list[str], + collection_fixture_name: str, + request: pytest.FixtureRequest, ) -> None: + collection: Collection = request.getfixturevalue(collection_fixture_name) + init_kwargs: dict[str, Any] + if is_vectorize: + init_kwargs = { + "collection_vector_service_options": OPENAI_VECTORIZE_OPTIONS_HEADER, + "collection_embedding_api_key": openai_api_key, + } + else: + init_kwargs = {"embedding": embedding_d2} + + content_field = CUSTOM_CONTENT_KEY if not is_vectorize else None + g_store = AstraDBGraphVectorStore.from_texts( - texts=["[1, 2]"], - embedding=embedding_d2, + texts=page_contents, metadatas=[{"md": 1}], ids=["x_id"], - collection_name=empty_collection_d2.name, - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - content_field=CUSTOM_CONTENT_KEY, + collection_name=collection.name, + content_field=content_field, setup_mode=SetupMode.OFF, + **auth_kwargs, + **init_kwargs, ) - hits = g_store.similarity_search("[2, 1]", k=2) + + query = "ukrainian food" if is_vectorize else "[2, 1]" + hits = g_store.similarity_search(query=query, k=2) assert len(hits) == 1 - assert hits[0].page_content == "[1, 2]" + assert hits[0].page_content == page_contents[0] assert hits[0].id == "x_id" - # there may be more re:graph structure. assert hits[0].metadata["md"] == 1 + @pytest.mark.parametrize( + ("is_vectorize", "page_contents", "collection_fixture_name"), + [ + (False, ["[1, 2]"], "empty_collection_d2"), + (True, ["tacos, tamales, and mole"], "empty_collection_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) def test_gvs_from_documents_containing_ids( self, *, - astra_db_credentials: AstraDBCredentials, - empty_collection_d2: Collection, + auth_kwargs: dict[str, Any], + openai_api_key: str, embedding_d2: Embeddings, + is_vectorize: bool, + page_contents: list[str], + collection_fixture_name: str, + request: pytest.FixtureRequest, ) -> None: + collection: Collection = request.getfixturevalue(collection_fixture_name) + init_kwargs: dict[str, Any] + if is_vectorize: + init_kwargs = { + "collection_vector_service_options": OPENAI_VECTORIZE_OPTIONS_HEADER, + "collection_embedding_api_key": openai_api_key, + } + else: + init_kwargs = {"embedding": embedding_d2} + + content_field = CUSTOM_CONTENT_KEY if not is_vectorize else None + the_document = Document( - page_content="[1, 2]", + page_content=page_contents[0], metadata={"md": 1}, id="x_id", ) g_store = AstraDBGraphVectorStore.from_documents( documents=[the_document], - embedding=embedding_d2, - collection_name=empty_collection_d2.name, - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - content_field=CUSTOM_CONTENT_KEY, + collection_name=collection.name, + content_field=content_field, setup_mode=SetupMode.OFF, + **auth_kwargs, + **init_kwargs, ) - hits = g_store.similarity_search("[2, 1]", k=2) + + query = "mexican food" if is_vectorize else "[2, 1]" + hits = g_store.similarity_search(query=query, k=2) assert len(hits) == 1 - assert hits[0].page_content == "[1, 2]" + assert hits[0].page_content == page_contents[0] assert hits[0].id == "x_id" - # there may be more re:graph structure. assert hits[0].metadata["md"] == 1 + @pytest.mark.parametrize( + ("is_vectorize", "page_contents", "store_name"), + [ + (False, ["[0, 2]", "[0, 1]"], "graph_vector_store_d2"), + (True, ["lasagna", "hamburger"], "graph_vector_store_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) def test_gvs_add_nodes_sync( self, *, - graph_vector_store_d2: AstraDBGraphVectorStore, + is_vectorize: bool, + page_contents: list[str], + store_name: str, + request: pytest.FixtureRequest, ) -> None: + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) links0 = [ Link(kind="kA", direction="out", tag="tA"), Link(kind="kB", direction="bidir", tag="tB"), @@ -549,28 +867,40 @@ def test_gvs_add_nodes_sync( Link(kind="kC", direction="in", tag="tC"), ] nodes = [ - Node(id="id0", text="[0, 2]", metadata={"m": 0}, links=links0), - Node(text="[0, 1]", metadata={"m": 1}, links=links1), + Node(id="id0", text=page_contents[0], metadata={"m": 0}, links=links0), + Node(text=page_contents[1], metadata={"m": 1}, links=links1), ] - graph_vector_store_d2.add_nodes(nodes) - hits = graph_vector_store_d2.similarity_search_by_vector([0, 3]) + g_store.add_nodes(nodes) + + query = "italian food" if is_vectorize else "[0, 3]" + hits = g_store.similarity_search(query=query) assert len(hits) == 2 assert hits[0].id == "id0" - assert hits[0].page_content == "[0, 2]" md0 = hits[0].metadata assert md0["m"] == 0 assert any(isinstance(v, set) for k, v in md0.items() if k != "m") assert hits[1].id != "id0" - assert hits[1].page_content == "[0, 1]" md1 = hits[1].metadata assert md1["m"] == 1 assert any(isinstance(v, set) for k, v in md1.items() if k != "m") + @pytest.mark.parametrize( + ("is_vectorize", "page_contents", "store_name"), + [ + (False, ["[0, 2]", "[0, 1]"], "graph_vector_store_d2"), + (True, ["lasagna", "hamburger"], "graph_vector_store_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) async def test_gvs_add_nodes_async( self, *, - graph_vector_store_d2: AstraDBGraphVectorStore, + is_vectorize: bool, + page_contents: list[str], + store_name: str, + request: pytest.FixtureRequest, ) -> None: + g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) links0 = [ Link(kind="kA", direction="out", tag="tA"), Link(kind="kB", direction="bidir", tag="tB"), @@ -579,21 +909,20 @@ async def test_gvs_add_nodes_async( Link(kind="kC", direction="in", tag="tC"), ] nodes = [ - Node(id="id0", text="[0, 2]", metadata={"m": 0}, links=links0), - Node(text="[0, 1]", metadata={"m": 1}, links=links1), + Node(id="id0", text=page_contents[0], metadata={"m": 0}, links=links0), + Node(text=page_contents[1], metadata={"m": 1}, links=links1), ] - async for _ in graph_vector_store_d2.aadd_nodes(nodes): + async for _ in g_store.aadd_nodes(nodes): pass - hits = await graph_vector_store_d2.asimilarity_search_by_vector([0, 3]) + query = "italian food" if is_vectorize else "[0, 3]" + hits = await g_store.asimilarity_search(query=query) assert len(hits) == 2 assert hits[0].id == "id0" - assert hits[0].page_content == "[0, 2]" md0 = hits[0].metadata assert md0["m"] == 0 assert any(isinstance(v, set) for k, v in md0.items() if k != "m") assert hits[1].id != "id0" - assert hits[1].page_content == "[0, 1]" md1 = hits[1].metadata assert md1["m"] == 1 assert any(isinstance(v, set) for k, v in md1.items() if k != "m") diff --git a/libs/astradb/tests/integration_tests/test_vectorstore.py b/libs/astradb/tests/integration_tests/test_vectorstore.py index 6a6a482..e80a5f1 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstore.py +++ b/libs/astradb/tests/integration_tests/test_vectorstore.py @@ -5,6 +5,7 @@ import json import math import os +import random from typing import TYPE_CHECKING, Any import pytest @@ -29,6 +30,11 @@ from .conftest import AstraDBCredentials +def assert_list_of_numeric(value: list[float]) -> None: + assert isinstance(value, list) + assert all(isinstance(item, (float, int)) for item in value) + + @pytest.fixture def metadata_documents() -> list[Document]: """Documents for metadata and id tests""" @@ -824,7 +830,6 @@ async def test_astradb_vectorstore_massive_insert_replace_async( all_ids = [f"doc_{idx}" for idx in range(full_size)] all_texts = [f"[0,{idx + 1}]" for idx in range(full_size)] - all_embeddings = [[0, idx + 1] for idx in range(full_size)] # massive insertion on empty group0_ids = all_ids[0:first_group_size] @@ -856,16 +861,6 @@ async def test_astradb_vectorstore_massive_insert_replace_async( ) for doc, _, doc_id in full_results: assert doc.page_content == expected_text_by_id[doc_id] - expected_embedding_by_id = dict(zip(all_ids, all_embeddings)) - full_results_with_embeddings = ( - await vector_store_d2.asimilarity_search_with_embedding_id_by_vector( - embedding=[1.0, 1.0], - k=full_size, - ) - ) - for doc, embedding, doc_id in full_results_with_embeddings: - assert doc.page_content == expected_text_by_id[doc_id] - assert embedding == expected_embedding_by_id[doc_id] def test_astradb_vectorstore_delete_by_metadata_sync( self, @@ -1419,6 +1414,142 @@ async def test_astradb_vectorstore_similarity_scale_async( assert abs(1 - sco_near) < MATCH_EPSILON assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_vz", + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + async def test_astradb_vectorstore_asimilarity_search_with_embedding( + self, + *, + vector_store: str, + metadata_documents: list[Document], + request: pytest.FixtureRequest, + ) -> None: + """asimilarity_search_with_embedding is used as the building + block for other components (like AstraDBGraphVectorStore). + """ + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + await vstore.aadd_documents(metadata_documents) + + query_embedding, results = await vstore.asimilarity_search_with_embedding( + query="[-1,2]" + ) + + assert_list_of_numeric(query_embedding) + assert isinstance(results, list) + assert len(results) > 0 + (doc, embedding) = results[0] + assert isinstance(doc, Document) + assert_list_of_numeric(embedding) + + @pytest.mark.parametrize( + ("is_vectorize", "vector_store"), + [ + (False, "vector_store_d2"), + (True, "vector_store_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + async def test_astradb_vectorstore_asimilarity_search_with_embedding_by_vector( + self, + *, + is_vectorize: bool, + vector_store: str, + metadata_documents: list[Document], + request: pytest.FixtureRequest, + ) -> None: + """asimilarity_search_with_embedding_by_vector is used as the building + block for other components (like AstraDBGraphVectorStore). + """ + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + await vstore.aadd_documents(metadata_documents) + + vector_dimensions = 1536 if is_vectorize else 2 + results = await vstore.asimilarity_search_with_embedding_by_vector( + embedding=[ + random.uniform(0.0, 1.0) # noqa: S311 + for _ in range(vector_dimensions) + ] + ) + + assert isinstance(results, list) + assert len(results) > 0 + (doc, embedding) = results[0] + assert isinstance(doc, Document) + assert_list_of_numeric(embedding) + + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_vz", + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + def test_astradb_vectorstore_similarity_search_with_embedding( + self, + *, + vector_store: str, + metadata_documents: list[Document], + request: pytest.FixtureRequest, + ) -> None: + """similarity_search_with_embedding is used as the building + block for other components (like AstraDBGraphVectorStore). + """ + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.add_documents(metadata_documents) + + query_embedding, results = vstore.similarity_search_with_embedding( + query="[-1,2]" + ) + + assert_list_of_numeric(query_embedding) + assert isinstance(results, list) + assert len(results) > 0 + (doc, embedding) = results[0] + assert isinstance(doc, Document) + assert_list_of_numeric(embedding) + + @pytest.mark.parametrize( + ("is_vectorize", "vector_store"), + [ + (False, "vector_store_d2"), + (True, "vector_store_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + def test_astradb_vectorstore_similarity_search_with_embedding_by_vector( + self, + *, + is_vectorize: bool, + vector_store: str, + metadata_documents: list[Document], + request: pytest.FixtureRequest, + ) -> None: + """similarity_search_with_embedding_by_vector is used as the building + block for other components (like AstraDBGraphVectorStore). + """ + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.add_documents(metadata_documents) + + vector_dimensions = 1536 if is_vectorize else 2 + results = vstore.similarity_search_with_embedding_by_vector( + embedding=[ + random.uniform(0.0, 1.0) # noqa: S311 + for _ in range(vector_dimensions) + ] + ) + + assert isinstance(results, list) + assert len(results) > 0 + (doc, embedding) = results[0] + assert isinstance(doc, Document) + assert_list_of_numeric(embedding) + @pytest.mark.parametrize( "vector_store", [