Use blockbuster to detect blocking calls in the event loop (#105)
cbornet authored Dec 19, 2024
1 parent 25ad55f commit bee975f
Showing 9 changed files with 131 additions and 41 deletions.
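For context: blockbuster (added in pyproject.toml below) patches a catalogue of known blocking functions and raises as soon as one of them runs while an asyncio event loop is active. A minimal sketch of the idea, assuming pytest with pytest-asyncio; the fixture and test below are illustrative and not part of this commit:

```python
import asyncio
import time

import pytest
from blockbuster import blockbuster_ctx


@pytest.fixture(autouse=True)
def blockbuster():
    # Activate blocking-call detection for every test.
    with blockbuster_ctx() as bb:
        yield bb


@pytest.mark.asyncio
async def test_blocking_call_is_flagged() -> None:
    # time.sleep() blocks the event loop, so blockbuster raises its
    # blocking-call error here instead of letting the test stall silently.
    with pytest.raises(Exception):
        time.sleep(0.1)

    # The awaitable equivalent passes untouched.
    await asyncio.sleep(0.1)
```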
2 changes: 1 addition & 1 deletion libs/astradb/langchain_astradb/graph_vectorstores.py
@@ -1255,7 +1255,7 @@ async def visit_targets(d: int, docs: Iterable[Document]) -> None:
await asyncio.gather(*visit_node_tasks)

# Start the traversal
initial_docs = self.vector_store.similarity_search(
initial_docs = await self.vector_store.asimilarity_search(
query=query,
k=k,
filter=filter,
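For context: the hunk above swaps a synchronous search call inside the async traversal entry point for its awaitable counterpart, since calling the sync method from a coroutine would run its network I/O on the event loop thread. A rough sketch of the pattern, with made-up class and method bodies (not the library's actual classes):

```python
import asyncio


class TinyStore:
    def similarity_search(self, query: str) -> list[str]:
        # Sync variant: performs blocking network I/O in the calling thread.
        return []

    async def asimilarity_search(self, query: str) -> list[str]:
        # Async variant: awaits a non-blocking client instead.
        await asyncio.sleep(0)  # stand-in for real async I/O
        return []


async def start_traversal(store: TinyStore, query: str) -> list[str]:
    # Inside a coroutine, await the async twin rather than calling the
    # sync method and stalling every other task on the loop.
    return await store.asimilarity_search(query)
```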
6 changes: 3 additions & 3 deletions libs/astradb/langchain_astradb/utils/astradb.py
@@ -425,9 +425,9 @@ async def _asetup_db(
except DataAPIException as data_api_exception:
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
collection_descriptors = [
coll_desc async for coll_desc in self.async_database.list_collections()
]
collection_descriptors = list(
await asyncio.to_thread(self.database.list_collections)
)
try:
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
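Here the collection listing is pushed onto a worker thread with asyncio.to_thread instead of running on the event loop. A small sketch of that pattern, with a placeholder standing in for the blocking client call:

```python
import asyncio


def list_collections_sync() -> list[str]:
    # Placeholder for a blocking HTTP/database round trip.
    return ["collection_a", "collection_b"]


async def list_collections_async() -> list[str]:
    # asyncio.to_thread runs the sync callable in the default executor and
    # hands back an awaitable, so the event loop stays responsive meanwhile.
    return list(await asyncio.to_thread(list_collections_sync))
```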
26 changes: 25 additions & 1 deletion libs/astradb/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions libs/astradb/pyproject.toml
@@ -32,6 +32,7 @@ langchain = { git = "https://github.com/langchain-ai/langchain.git", subdirector
langchain-text-splitters = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/text-splitters" }
langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
langchain-community = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/community" }
blockbuster = "^1.2.0"

[tool.poetry.group.codespell]
optional = true
20 changes: 19 additions & 1 deletion libs/astradb/tests/conftest.py
@@ -5,8 +5,10 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterator

import pytest
from blockbuster import BlockBuster, blockbuster_ctx
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import LLM
from typing_extensions import override
@@ -15,6 +17,22 @@
from langchain_core.callbacks import CallbackManagerForLLMRun


@pytest.fixture(autouse=True)
def blockbuster() -> Iterator[BlockBuster]:
with blockbuster_ctx() as bb:
# TODO: GraphVectorStore init is blocking. Should be fixed.
for method in (
"socket.socket.connect",
"ssl.SSLSocket.send",
"ssl.SSLSocket.recv",
"ssl.SSLSocket.read",
):
bb.functions[method].can_block_functions.append(
("langchain_astradb/graph_vectorstores.py", {"__init__"}),
)
yield bb


class ParserEmbeddings(Embeddings):
"""Parse input texts: if they are json for a List[float], fine.
Otherwise, return all zeros and call it a day.
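The fixture above enables blockbuster for every test and then allow-lists a few low-level socket/SSL calls, but only when they are reached from GraphVectorStore.__init__ (see the TODO). Each allow-list entry pairs a file path with a set of function names from which the patched call is permitted to block. A hypothetical extension of the same mechanism; the module path and function name below are invented for illustration:

```python
from typing import Iterator

import pytest
from blockbuster import BlockBuster, blockbuster_ctx


@pytest.fixture(autouse=True)
def blockbuster() -> Iterator[BlockBuster]:
    with blockbuster_ctx() as bb:
        # Allow socket connects only when they originate from one known,
        # accepted-as-blocking call site instead of disabling the check
        # globally. File and function names here are made up.
        bb.functions["socket.socket.connect"].can_block_functions.append(
            ("langchain_astradb/some_module.py", {"_warm_up_connection"}),
        )
        yield bb
```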
55 changes: 39 additions & 16 deletions libs/astradb/tests/integration_tests/test_graphvectorstore.py
@@ -5,6 +5,7 @@

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any

import pytest
@@ -400,7 +401,9 @@ async def test_gvs_similarity_search_async(
request: pytest.FixtureRequest,
) -> None:
"""Simple (non-graph) similarity search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
query = "universe" if is_vectorize else "[2, 10]"
embedding = [2.0, 10.0]

@@ -421,8 +424,10 @@ async def test_gvs_similarity_search_async(
assert ss_by_v_labels == ["AR", "A0"]

if is_autodetected:
assert_all_flat_docs(
g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize
await asyncio.to_thread(
assert_all_flat_docs,
g_store.vector_store.astra_env.collection,
is_vectorize=is_vectorize,
)

@pytest.mark.parametrize(
@@ -488,7 +493,9 @@ async def test_gvs_traversal_search_async(
request: pytest.FixtureRequest,
) -> None:
"""Graph traversal search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
query = "universe" if is_vectorize else "[2, 10]"

# this is a set, as some of the internals of trav.search are set-driven
@@ -499,8 +506,10 @@ async def test_gvs_traversal_search_async(
}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
if is_autodetected:
assert_all_flat_docs(
g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize
await asyncio.to_thread(
assert_all_flat_docs,
g_store.vector_store.astra_env.collection,
is_vectorize=is_vectorize,
)

@pytest.mark.parametrize(
@@ -572,7 +581,9 @@ async def test_gvs_mmr_traversal_search_async(
request: pytest.FixtureRequest,
) -> None:
"""MMR Graph traversal search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
query = "universe" if is_vectorize else "[2, 10]"

mt_labels = [
@@ -589,8 +600,10 @@ async def test_gvs_mmr_traversal_search_async(

assert mt_labels == ["AR", "BR"]
if is_autodetected:
assert_all_flat_docs(
g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize
await asyncio.to_thread(
assert_all_flat_docs,
g_store.vector_store.astra_env.collection,
is_vectorize=is_vectorize,
)

@pytest.mark.parametrize(
@@ -652,7 +665,9 @@ async def test_gvs_metadata_search_async(
request: pytest.FixtureRequest,
) -> None:
"""Metadata search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
mt_response = await g_store.ametadata_search(
filter={"label": "T0"},
n=2,
@@ -726,7 +741,9 @@ async def test_gvs_get_by_document_id_async(
request: pytest.FixtureRequest,
) -> None:
"""Get by document_id on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
doc = await g_store.aget_by_document_id(document_id="FL")
assert doc is not None
assert doc.metadata["label"] == "FL"
@@ -816,7 +833,9 @@ async def test_gvs_from_texts_async(
collection_fixture_name: str,
request: pytest.FixtureRequest,
) -> None:
collection_name: str = request.getfixturevalue(collection_fixture_name)
collection_name: str = await asyncio.to_thread(
request.getfixturevalue, collection_fixture_name
)
init_kwargs: dict[str, Any]
if is_vectorize:
init_kwargs = {
@@ -839,7 +858,7 @@ async def test_gvs_from_texts_async(
)

query = "ukrainian food" if is_vectorize else "[2, 1]"
hits = g_store.similarity_search(query=query, k=2)
hits = await g_store.asimilarity_search(query=query, k=2)
assert len(hits) == 1
assert hits[0].page_content == page_contents[0]
assert hits[0].id == "x_id"
@@ -915,7 +934,9 @@ async def test_gvs_from_documents_containing_ids_async(
collection_fixture_name: str,
request: pytest.FixtureRequest,
) -> None:
collection_name: str = request.getfixturevalue(collection_fixture_name)
collection_name: str = await asyncio.to_thread(
request.getfixturevalue, collection_fixture_name
)
init_kwargs: dict[str, Any]
if is_vectorize:
init_kwargs = {
@@ -941,7 +962,7 @@ async def test_gvs_from_documents_containing_ids_async(
)

query = "mexican food" if is_vectorize else "[2, 1]"
hits = g_store.similarity_search(query=query, k=2)
hits = await g_store.asimilarity_search(query=query, k=2)
assert len(hits) == 1
assert hits[0].page_content == page_contents[0]
assert hits[0].id == "x_id"
@@ -1005,7 +1026,9 @@ async def test_gvs_add_nodes_async(
store_name: str,
request: pytest.FixtureRequest,
) -> None:
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
links0 = [
Link(kind="kA", direction="out", tag="tA"),
Link(kind="kB", direction="bidir", tag="tB"),
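Throughout this test module, fixture resolution and the synchronous assertion helper are wrapped in asyncio.to_thread, since both can hit the database with blocking I/O while the test itself runs on the event loop. A condensed, self-contained sketch of that pattern; the fixture and helper names are made up:

```python
import asyncio

import pytest


@pytest.fixture
def some_store_fixture() -> dict:
    # Stand-in for a fixture whose setup performs blocking network I/O.
    return {"ready": True}


@pytest.mark.asyncio
async def test_uses_blocking_setup(request: pytest.FixtureRequest) -> None:
    # Resolving the fixture lazily runs its blocking setup, so do it in a
    # worker thread rather than on the event loop thread.
    store = await asyncio.to_thread(request.getfixturevalue, "some_store_fixture")

    # to_thread forwards positional and keyword arguments to the callable,
    # so synchronous helpers can be offloaded the same way.
    await asyncio.to_thread(check_store, store, strict=True)


def check_store(store: dict, *, strict: bool) -> None:
    # Placeholder for an assertion helper that would normally query the DB.
    assert store["ready"] or not strict
```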
6 changes: 3 additions & 3 deletions libs/astradb/tests/integration_tests/test_storage.py
@@ -44,7 +44,7 @@ async def astra_db_empty_store_async(
astra_db_credentials: AstraDBCredentials,
collection_idxid: Collection,
) -> AstraDBStore:
collection_idxid.delete_many({})
await collection_idxid.to_async().delete_many({})
return AstraDBStore(
collection_name=collection_idxid.name,
token=StaticTokenProvider(astra_db_credentials["token"]),
@@ -364,7 +364,7 @@ async def test_store_indexing_on_legacy_async(
database: Database,
) -> None:
"""Test of instantiation against a legacy collection, async version."""
database.create_collection(
await database.to_async().create_collection(
EPHEMERAL_LEGACY_IDX_NAME,
indexing=None,
check_exists=False,
@@ -413,7 +413,7 @@ async def test_store_indexing_on_custom_async(
database: Database,
) -> None:
"""Test of instantiation against a legacy collection, async version."""
database.create_collection(
await database.to_async().create_collection(
EPHEMERAL_CUSTOM_IDX_NAME,
indexing={"deny": ["useless", "forgettable"]},
check_exists=False,
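In these storage tests, the synchronous astrapy handles are converted with .to_async() so the Data API calls can be awaited instead of blocking the loop. A brief sketch of that conversion; only the to_async() calls and argument shapes mirror the diff, and the surrounding setup is omitted:

```python
from astrapy import Collection, Database


async def clear_collection(collection: Collection) -> None:
    # Sync form (blocks the event loop): collection.delete_many({})
    # Async form used above: convert the handle, then await the call.
    await collection.to_async().delete_many({})


async def create_legacy_collection(database: Database, name: str) -> None:
    # The same conversion exists at the Database level.
    await database.to_async().create_collection(name, indexing=None)
```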