diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index e14e78e..5134acb 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -1255,7 +1255,7 @@ async def visit_targets(d: int, docs: Iterable[Document]) -> None: await asyncio.gather(*visit_node_tasks) # Start the traversal - initial_docs = self.vector_store.similarity_search( + initial_docs = await self.vector_store.asimilarity_search( query=query, k=k, filter=filter, diff --git a/libs/astradb/langchain_astradb/utils/astradb.py b/libs/astradb/langchain_astradb/utils/astradb.py index 66d88dd..7825c85 100644 --- a/libs/astradb/langchain_astradb/utils/astradb.py +++ b/libs/astradb/langchain_astradb/utils/astradb.py @@ -425,9 +425,9 @@ async def _asetup_db( except DataAPIException as data_api_exception: # possibly the collection is preexisting and may have legacy, # or custom, indexing settings: verify - collection_descriptors = [ - coll_desc async for coll_desc in self.async_database.list_collections() - ] + collection_descriptors = list( + await asyncio.to_thread(self.database.list_collections) + ) try: if not self._validate_indexing_policy( collection_descriptors=collection_descriptors, diff --git a/libs/astradb/poetry.lock b/libs/astradb/poetry.lock index 4194c94..b8c3125 100644 --- a/libs/astradb/poetry.lock +++ b/libs/astradb/poetry.lock @@ -218,6 +218,20 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +[[package]] +name = "blockbuster" +version = "1.2.0" +description = "Utility to detect blocking calls in the async event loop" +optional = false +python-versions = ">=3.8" +files = [ + {file = "blockbuster-1.2.0-py3-none-any.whl", hash = "sha256:5210faccc22695bd3c338d3de2ec0581d5a270729e1b18d98d78eefd95eea2c5"}, + {file = "blockbuster-1.2.0.tar.gz", hash = "sha256:c54f5184debf708488447fec53205274518ec806b9fb551d9e6d489b5a26b703"}, +] + +[package.dependencies] +forbiddenfruit = ">=0.1.4" + [[package]] name = "certifi" version = "2024.8.30" @@ -419,6 +433,16 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "forbiddenfruit" +version = "0.1.4" +description = "Patch python built-in objects" +optional = false +python-versions = "*" +files = [ + {file = "forbiddenfruit-0.1.4.tar.gz", hash = "sha256:e3f7e66561a29ae129aac139a85d610dbf3dd896128187ed5454b6421f624253"}, +] + [[package]] name = "freezegun" version = "1.5.1" @@ -2240,4 +2264,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "6041fa7043d4f92cd9beedece338a472b725f110a12f1b6ff5510bf1bf3f2c1b" +content-hash = "dda1c3c1d1d906f1ea5d68c186e200ba3e6c82798f8c2750ef35a25b0cef6dca" diff --git a/libs/astradb/pyproject.toml b/libs/astradb/pyproject.toml index c455700..a574f3f 100644 --- a/libs/astradb/pyproject.toml +++ b/libs/astradb/pyproject.toml @@ -32,6 +32,7 @@ langchain = { git = "https://github.com/langchain-ai/langchain.git", subdirector langchain-text-splitters = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/text-splitters" } langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } langchain-community = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/community" } +blockbuster = "^1.2.0" [tool.poetry.group.codespell] optional = true diff --git a/libs/astradb/tests/conftest.py b/libs/astradb/tests/conftest.py index db44211..518c903 100644 --- a/libs/astradb/tests/conftest.py +++ b/libs/astradb/tests/conftest.py @@ -5,8 +5,10 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterator +import pytest +from blockbuster import BlockBuster, blockbuster_ctx from langchain_core.embeddings import Embeddings from langchain_core.language_models import LLM from typing_extensions import override @@ -15,6 +17,21 @@ from langchain_core.callbacks import CallbackManagerForLLMRun +@pytest.fixture(autouse=True) +def blockbuster() -> Iterator[BlockBuster]: + with blockbuster_ctx() as bb: + for method in ( + "socket.socket.connect", + "ssl.SSLSocket.send", + "ssl.SSLSocket.recv", + "ssl.SSLSocket.read", + ): + bb.functions[method].can_block_functions.append( + ("langchain_astradb/graph_vectorstores.py", {"__init__"}), + ) + yield bb + + class ParserEmbeddings(Embeddings): """Parse input texts: if they are json for a List[float], fine. Otherwise, return all zeros and call it a day. diff --git a/libs/astradb/tests/integration_tests/test_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_graphvectorstore.py index 515009f..8bae1df 100644 --- a/libs/astradb/tests/integration_tests/test_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_graphvectorstore.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING, Any import pytest @@ -400,7 +401,9 @@ async def test_gvs_similarity_search_async( request: pytest.FixtureRequest, ) -> None: """Simple (non-graph) similarity search on a graph vector store.""" - g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + g_store: AstraDBGraphVectorStore = await asyncio.to_thread( + request.getfixturevalue, store_name + ) query = "universe" if is_vectorize else "[2, 10]" embedding = [2.0, 10.0] @@ -421,8 +424,10 @@ async def test_gvs_similarity_search_async( assert ss_by_v_labels == ["AR", "A0"] if is_autodetected: - assert_all_flat_docs( - g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize + await asyncio.to_thread( + assert_all_flat_docs, + g_store.vector_store.astra_env.collection, + is_vectorize=is_vectorize, ) @pytest.mark.parametrize( @@ -488,7 +493,9 @@ async def test_gvs_traversal_search_async( request: pytest.FixtureRequest, ) -> None: """Graph traversal search on a graph vector store.""" - g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + g_store: AstraDBGraphVectorStore = await asyncio.to_thread( + request.getfixturevalue, store_name + ) query = "universe" if is_vectorize else "[2, 10]" # this is a set, as some of the internals of trav.search are set-driven @@ -499,8 +506,10 @@ async def test_gvs_traversal_search_async( } assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} if is_autodetected: - assert_all_flat_docs( - g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize + await asyncio.to_thread( + assert_all_flat_docs, + g_store.vector_store.astra_env.collection, + is_vectorize=is_vectorize, ) @pytest.mark.parametrize( @@ -572,7 +581,9 @@ async def test_gvs_mmr_traversal_search_async( request: pytest.FixtureRequest, ) -> None: """MMR Graph traversal search on a graph vector store.""" - g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + g_store: AstraDBGraphVectorStore = await asyncio.to_thread( + request.getfixturevalue, store_name + ) query = "universe" if is_vectorize else "[2, 10]" mt_labels = [ @@ -589,8 +600,10 @@ async def test_gvs_mmr_traversal_search_async( assert mt_labels == ["AR", "BR"] if is_autodetected: - assert_all_flat_docs( - g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize + await asyncio.to_thread( + assert_all_flat_docs, + g_store.vector_store.astra_env.collection, + is_vectorize=is_vectorize, ) @pytest.mark.parametrize( @@ -652,7 +665,9 @@ async def test_gvs_metadata_search_async( request: pytest.FixtureRequest, ) -> None: """Metadata search on a graph vector store.""" - g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + g_store: AstraDBGraphVectorStore = await asyncio.to_thread( + request.getfixturevalue, store_name + ) mt_response = await g_store.ametadata_search( filter={"label": "T0"}, n=2, @@ -726,7 +741,9 @@ async def test_gvs_get_by_document_id_async( request: pytest.FixtureRequest, ) -> None: """Get by document_id on a graph vector store.""" - g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + g_store: AstraDBGraphVectorStore = await asyncio.to_thread( + request.getfixturevalue, store_name + ) doc = await g_store.aget_by_document_id(document_id="FL") assert doc is not None assert doc.metadata["label"] == "FL" @@ -816,7 +833,9 @@ async def test_gvs_from_texts_async( collection_fixture_name: str, request: pytest.FixtureRequest, ) -> None: - collection_name: str = request.getfixturevalue(collection_fixture_name) + collection_name: str = await asyncio.to_thread( + request.getfixturevalue, collection_fixture_name + ) init_kwargs: dict[str, Any] if is_vectorize: init_kwargs = { @@ -839,7 +858,7 @@ async def test_gvs_from_texts_async( ) query = "ukrainian food" if is_vectorize else "[2, 1]" - hits = g_store.similarity_search(query=query, k=2) + hits = await g_store.asimilarity_search(query=query, k=2) assert len(hits) == 1 assert hits[0].page_content == page_contents[0] assert hits[0].id == "x_id" @@ -915,7 +934,9 @@ async def test_gvs_from_documents_containing_ids_async( collection_fixture_name: str, request: pytest.FixtureRequest, ) -> None: - collection_name: str = request.getfixturevalue(collection_fixture_name) + collection_name: str = await asyncio.to_thread( + request.getfixturevalue, collection_fixture_name + ) init_kwargs: dict[str, Any] if is_vectorize: init_kwargs = { @@ -941,7 +962,7 @@ async def test_gvs_from_documents_containing_ids_async( ) query = "mexican food" if is_vectorize else "[2, 1]" - hits = g_store.similarity_search(query=query, k=2) + hits = await g_store.asimilarity_search(query=query, k=2) assert len(hits) == 1 assert hits[0].page_content == page_contents[0] assert hits[0].id == "x_id" @@ -1005,7 +1026,9 @@ async def test_gvs_add_nodes_async( store_name: str, request: pytest.FixtureRequest, ) -> None: - g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + g_store: AstraDBGraphVectorStore = await asyncio.to_thread( + request.getfixturevalue, store_name + ) links0 = [ Link(kind="kA", direction="out", tag="tA"), Link(kind="kB", direction="bidir", tag="tB"), diff --git a/libs/astradb/tests/integration_tests/test_storage.py b/libs/astradb/tests/integration_tests/test_storage.py index d105656..75eae09 100644 --- a/libs/astradb/tests/integration_tests/test_storage.py +++ b/libs/astradb/tests/integration_tests/test_storage.py @@ -44,7 +44,7 @@ async def astra_db_empty_store_async( astra_db_credentials: AstraDBCredentials, collection_idxid: Collection, ) -> AstraDBStore: - collection_idxid.delete_many({}) + await collection_idxid.to_async().delete_many({}) return AstraDBStore( collection_name=collection_idxid.name, token=StaticTokenProvider(astra_db_credentials["token"]), @@ -364,7 +364,7 @@ async def test_store_indexing_on_legacy_async( database: Database, ) -> None: """Test of instantiation against a legacy collection, async version.""" - database.create_collection( + await database.to_async().create_collection( EPHEMERAL_LEGACY_IDX_NAME, indexing=None, check_exists=False, @@ -413,7 +413,7 @@ async def test_store_indexing_on_custom_async( database: Database, ) -> None: """Test of instantiation against a legacy collection, async version.""" - database.create_collection( + await database.to_async().create_collection( EPHEMERAL_CUSTOM_IDX_NAME, indexing={"deny": ["useless", "forgettable"]}, check_exists=False, diff --git a/libs/astradb/tests/integration_tests/test_vectorstore.py b/libs/astradb/tests/integration_tests/test_vectorstore.py index e80a5f1..f7dba49 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstore.py +++ b/libs/astradb/tests/integration_tests/test_vectorstore.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import json import math import os @@ -371,7 +372,9 @@ async def test_astradb_vectorstore_from_texts_async( request: pytest.FixtureRequest, ) -> None: """from_texts methods and the associated warnings, async version.""" - collection: Collection = request.getfixturevalue(collection_fixture_name) + collection: Collection = await asyncio.to_thread( + request.getfixturevalue, collection_fixture_name + ) init_kwargs: dict[str, Any] if is_vectorize: init_kwargs = { @@ -481,7 +484,9 @@ async def test_astradb_vectorstore_from_documents_async( from_documents, esp. the various handling of ID-in-doc vs external. Async version. """ - collection: Collection = request.getfixturevalue(collection_fixture_name) + collection: Collection = await asyncio.to_thread( + request.getfixturevalue, collection_fixture_name + ) pc1, pc2 = page_contents init_kwargs: dict[str, Any] if is_vectorize: @@ -509,7 +514,7 @@ async def test_astradb_vectorstore_from_documents_async( assert len(hits) == 1 assert hits[0].page_content == pc2 assert hits[0].metadata == {"m": 3} - v_store.clear() + await v_store.aclear() # IDs passed separately. with pytest.warns(DeprecationWarning) as rec_warnings: @@ -536,7 +541,7 @@ async def test_astradb_vectorstore_from_documents_async( assert hits[0].page_content == pc2 assert hits[0].metadata == {"m": 3} assert hits[0].id == "idx3" - v_store_2.clear() + await v_store_2.aclear() # IDs in documents. v_store_3 = await AstraDBVectorStore.afrom_documents( @@ -557,7 +562,7 @@ async def test_astradb_vectorstore_from_documents_async( assert hits[0].page_content == pc2 assert hits[0].metadata == {"m": 3} assert hits[0].id == "idx3" - v_store_3.clear() + await v_store_3.aclear() # IDs both in documents and aside. with pytest.warns(DeprecationWarning) as rec_warnings: @@ -692,7 +697,9 @@ async def test_astradb_vectorstore_crud_async( request: pytest.FixtureRequest, ) -> None: """Add/delete/update behaviour, async version.""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore: AstraDBVectorStore = await asyncio.to_thread( + request.getfixturevalue, vector_store + ) res0 = await vstore.asimilarity_search("[-1,-1]", k=2) assert res0 == [] @@ -1242,13 +1249,15 @@ async def test_astradb_vectorstore_metadata_search_async( metadata_documents: list[Document], ) -> None: """Metadata Search""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore: AstraDBVectorStore = await asyncio.to_thread( + request.getfixturevalue, vector_store + ) await vstore.aadd_documents(metadata_documents) # no filters res0 = await vstore.ametadata_search(filter={}, n=10) assert {doc.metadata["letter"] for doc in res0} == set("qwreio") # single filter - res1 = vstore.metadata_search( + res1 = await vstore.ametadata_search( n=10, filter={"group": "vowel"}, ) @@ -1313,7 +1322,9 @@ async def test_astradb_vectorstore_get_by_document_id_async( metadata_documents: list[Document], ) -> None: """Get by document_id""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore: AstraDBVectorStore = await asyncio.to_thread( + request.getfixturevalue, vector_store + ) await vstore.aadd_documents(metadata_documents) # invalid id invalid = await vstore.aget_by_document_id(document_id="z") @@ -1398,7 +1409,9 @@ async def test_astradb_vectorstore_similarity_scale_async( request: pytest.FixtureRequest, ) -> None: """Scale of the similarity scores, async version.""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore: AstraDBVectorStore = await asyncio.to_thread( + request.getfixturevalue, vector_store + ) await vstore.aadd_texts( texts=texts, ids=["near", "far"], @@ -1432,7 +1445,9 @@ async def test_astradb_vectorstore_asimilarity_search_with_embedding( """asimilarity_search_with_embedding is used as the building block for other components (like AstraDBGraphVectorStore). """ - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore: AstraDBVectorStore = await asyncio.to_thread( + request.getfixturevalue, vector_store + ) await vstore.aadd_documents(metadata_documents) query_embedding, results = await vstore.asimilarity_search_with_embedding( @@ -1465,7 +1480,9 @@ async def test_astradb_vectorstore_asimilarity_search_with_embedding_by_vector( """asimilarity_search_with_embedding_by_vector is used as the building block for other components (like AstraDBGraphVectorStore). """ - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore: AstraDBVectorStore = await asyncio.to_thread( + request.getfixturevalue, vector_store + ) await vstore.aadd_documents(metadata_documents) vector_dimensions = 1536 if is_vectorize else 2 @@ -1748,7 +1765,7 @@ async def test_astradb_vectorstore_coreclients_init_async( Expect a deprecation warning from passing a (core) AstraDB class, but it must work. Async version. """ - vector_store_d2.add_texts(["[1,2]"]) + await vector_store_d2.aadd_texts(["[1,2]"]) with pytest.warns(DeprecationWarning) as rec_warnings: v_store_init_core = AstraDBVectorStore( diff --git a/libs/astradb/tests/integration_tests/test_vectorstore_ddl_tests.py b/libs/astradb/tests/integration_tests/test_vectorstore_ddl_tests.py index bec53be..8218ff5 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstore_ddl_tests.py +++ b/libs/astradb/tests/integration_tests/test_vectorstore_ddl_tests.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio import warnings from typing import TYPE_CHECKING @@ -96,10 +97,13 @@ async def test_astradb_vectorstore_create_delete_async( namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], metric="cosine", + setup_mode=SetupMode.ASYNC, ) await v_store.aadd_texts(["[1,2]"]) await v_store.adelete_collection() - assert ephemeral_collection_cleaner_d2 not in database.list_collection_names() + assert ephemeral_collection_cleaner_d2 not in await asyncio.to_thread( + database.list_collection_names + ) async def test_astradb_vectorstore_create_delete_vectorize_async( self, @@ -118,10 +122,13 @@ async def test_astradb_vectorstore_create_delete_vectorize_async( metric="cosine", collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, collection_embedding_api_key=openai_api_key, + setup_mode=SetupMode.ASYNC, ) await v_store.aadd_texts(["[1,2]"]) await v_store.adelete_collection() - assert ephemeral_collection_cleaner_vz not in database.list_collection_names() + assert ephemeral_collection_cleaner_vz not in await asyncio.to_thread( + database.list_collection_names + ) def test_astradb_vectorstore_pre_delete_collection_sync( self, @@ -340,7 +347,7 @@ async def test_astradb_vectorstore_indexing_legacy_async( Test of the vector store behaviour for various indexing settings, with an existing 'legacy' collection (i.e. unspecified indexing policy). """ - database.create_collection( + await database.to_async().create_collection( EPHEMERAL_LEGACY_IDX_NAME_D2, dimension=2, check_exists=False,