Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use blockbuster to detect blocking calls in the event loop #105

Merged
merged 1 commit into the base branch on
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/astradb/langchain_astradb/graph_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,7 @@ async def visit_targets(d: int, docs: Iterable[Document]) -> None:
await asyncio.gather(*visit_node_tasks)

# Start the traversal
initial_docs = self.vector_store.similarity_search(
initial_docs = await self.vector_store.asimilarity_search(
query=query,
k=k,
filter=filter,
Expand Down
6 changes: 3 additions & 3 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,9 @@ async def _asetup_db(
except DataAPIException as data_api_exception:
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
collection_descriptors = [
coll_desc async for coll_desc in self.async_database.list_collections()
]
collection_descriptors = list(
await asyncio.to_thread(self.database.list_collections)
)
try:
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
Expand Down
26 changes: 25 additions & 1 deletion libs/astradb/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/astradb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ langchain = { git = "https://github.com/langchain-ai/langchain.git", subdirector
langchain-text-splitters = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/text-splitters" }
langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
langchain-community = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/community" }
blockbuster = "^1.2.0"

[tool.poetry.group.codespell]
optional = true
Expand Down
20 changes: 19 additions & 1 deletion libs/astradb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterator

import pytest
from blockbuster import BlockBuster, blockbuster_ctx
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import LLM
from typing_extensions import override
Expand All @@ -15,6 +17,22 @@
from langchain_core.callbacks import CallbackManagerForLLMRun


@pytest.fixture(autouse=True)
def blockbuster() -> Iterator[BlockBuster]:
    """Activate blockbuster for every test so blocking calls in the event loop fail fast.

    Yields the active :class:`BlockBuster` instance after registering a
    narrow allow-list for the known-blocking ``GraphVectorStore.__init__``.
    """
    blocking_socket_methods = [
        "socket.socket.connect",
        "ssl.SSLSocket.send",
        "ssl.SSLSocket.recv",
        "ssl.SSLSocket.read",
    ]
    with blockbuster_ctx() as bb:
        # TODO: GraphVectorStore init is blocking. Should be fixed.
        for method_name in blocking_socket_methods:
            allowed = ("langchain_astradb/graph_vectorstores.py", {"__init__"})
            bb.functions[method_name].can_block_functions.append(allowed)
        yield bb


class ParserEmbeddings(Embeddings):
"""Parse input texts: if they are json for a List[float], fine.
Otherwise, return all zeros and call it a day.
Expand Down
55 changes: 39 additions & 16 deletions libs/astradb/tests/integration_tests/test_graphvectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any

import pytest
Expand Down Expand Up @@ -400,7 +401,9 @@ async def test_gvs_similarity_search_async(
request: pytest.FixtureRequest,
) -> None:
"""Simple (non-graph) similarity search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
query = "universe" if is_vectorize else "[2, 10]"
embedding = [2.0, 10.0]

Expand All @@ -421,8 +424,10 @@ async def test_gvs_similarity_search_async(
assert ss_by_v_labels == ["AR", "A0"]

if is_autodetected:
assert_all_flat_docs(
g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize
await asyncio.to_thread(
assert_all_flat_docs,
g_store.vector_store.astra_env.collection,
is_vectorize=is_vectorize,
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -488,7 +493,9 @@ async def test_gvs_traversal_search_async(
request: pytest.FixtureRequest,
) -> None:
"""Graph traversal search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
query = "universe" if is_vectorize else "[2, 10]"

# this is a set, as some of the internals of trav.search are set-driven
Expand All @@ -499,8 +506,10 @@ async def test_gvs_traversal_search_async(
}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
if is_autodetected:
assert_all_flat_docs(
g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize
await asyncio.to_thread(
assert_all_flat_docs,
g_store.vector_store.astra_env.collection,
is_vectorize=is_vectorize,
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -572,7 +581,9 @@ async def test_gvs_mmr_traversal_search_async(
request: pytest.FixtureRequest,
) -> None:
"""MMR Graph traversal search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
query = "universe" if is_vectorize else "[2, 10]"

mt_labels = [
Expand All @@ -589,8 +600,10 @@ async def test_gvs_mmr_traversal_search_async(

assert mt_labels == ["AR", "BR"]
if is_autodetected:
assert_all_flat_docs(
g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize
await asyncio.to_thread(
assert_all_flat_docs,
g_store.vector_store.astra_env.collection,
is_vectorize=is_vectorize,
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -652,7 +665,9 @@ async def test_gvs_metadata_search_async(
request: pytest.FixtureRequest,
) -> None:
"""Metadata search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
mt_response = await g_store.ametadata_search(
filter={"label": "T0"},
n=2,
Expand Down Expand Up @@ -726,7 +741,9 @@ async def test_gvs_get_by_document_id_async(
request: pytest.FixtureRequest,
) -> None:
"""Get by document_id on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
doc = await g_store.aget_by_document_id(document_id="FL")
assert doc is not None
assert doc.metadata["label"] == "FL"
Expand Down Expand Up @@ -816,7 +833,9 @@ async def test_gvs_from_texts_async(
collection_fixture_name: str,
request: pytest.FixtureRequest,
) -> None:
collection_name: str = request.getfixturevalue(collection_fixture_name)
collection_name: str = await asyncio.to_thread(
request.getfixturevalue, collection_fixture_name
)
init_kwargs: dict[str, Any]
if is_vectorize:
init_kwargs = {
Expand All @@ -839,7 +858,7 @@ async def test_gvs_from_texts_async(
)

query = "ukrainian food" if is_vectorize else "[2, 1]"
hits = g_store.similarity_search(query=query, k=2)
hits = await g_store.asimilarity_search(query=query, k=2)
assert len(hits) == 1
assert hits[0].page_content == page_contents[0]
assert hits[0].id == "x_id"
Expand Down Expand Up @@ -915,7 +934,9 @@ async def test_gvs_from_documents_containing_ids_async(
collection_fixture_name: str,
request: pytest.FixtureRequest,
) -> None:
collection_name: str = request.getfixturevalue(collection_fixture_name)
collection_name: str = await asyncio.to_thread(
request.getfixturevalue, collection_fixture_name
)
init_kwargs: dict[str, Any]
if is_vectorize:
init_kwargs = {
Expand All @@ -941,7 +962,7 @@ async def test_gvs_from_documents_containing_ids_async(
)

query = "mexican food" if is_vectorize else "[2, 1]"
hits = g_store.similarity_search(query=query, k=2)
hits = await g_store.asimilarity_search(query=query, k=2)
assert len(hits) == 1
assert hits[0].page_content == page_contents[0]
assert hits[0].id == "x_id"
Expand Down Expand Up @@ -1005,7 +1026,9 @@ async def test_gvs_add_nodes_async(
store_name: str,
request: pytest.FixtureRequest,
) -> None:
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
links0 = [
Link(kind="kA", direction="out", tag="tA"),
Link(kind="kB", direction="bidir", tag="tB"),
Expand Down
6 changes: 3 additions & 3 deletions libs/astradb/tests/integration_tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def astra_db_empty_store_async(
astra_db_credentials: AstraDBCredentials,
collection_idxid: Collection,
) -> AstraDBStore:
collection_idxid.delete_many({})
await collection_idxid.to_async().delete_many({})
return AstraDBStore(
collection_name=collection_idxid.name,
token=StaticTokenProvider(astra_db_credentials["token"]),
Expand Down Expand Up @@ -364,7 +364,7 @@ async def test_store_indexing_on_legacy_async(
database: Database,
) -> None:
"""Test of instantiation against a legacy collection, async version."""
database.create_collection(
await database.to_async().create_collection(
EPHEMERAL_LEGACY_IDX_NAME,
indexing=None,
check_exists=False,
Expand Down Expand Up @@ -413,7 +413,7 @@ async def test_store_indexing_on_custom_async(
database: Database,
) -> None:
"""Test of instantiation against a legacy collection, async version."""
database.create_collection(
await database.to_async().create_collection(
EPHEMERAL_CUSTOM_IDX_NAME,
indexing={"deny": ["useless", "forgettable"]},
check_exists=False,
Expand Down
Loading
Loading