Commit 0983c3d

Use Blockbuster to detect blocking calls in asyncio during tests

cbornet committed Jan 6, 2025
1 parent edbe7d5 commit 0983c3d

Showing 16 changed files with 331 additions and 132 deletions.
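
Blockbuster works by patching functions known to block (time.sleep, os.stat, socket and file I/O, ...) so that calling them while an asyncio event loop is running raises an error, which makes accidental blocking calls fail loudly in tests. A minimal sketch of the pattern this commit adopts — not code from the commit; it assumes blockbuster's documented blockbuster_ctx() and BlockingError, plus pytest-asyncio:

import time

import pytest
from blockbuster import BlockingError, blockbuster_ctx


@pytest.fixture(autouse=True)
def blockbuster():
    # Activate Blockbuster for the duration of every test.
    with blockbuster_ctx() as bb:
        yield bb


@pytest.mark.asyncio
async def test_sleep_is_flagged() -> None:
    # time.sleep() would block the running event loop, so Blockbuster
    # intercepts it and raises BlockingError instead.
    with pytest.raises(BlockingError):
        time.sleep(0.01)
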
6 changes: 4 additions & 2 deletions libs/core/langchain_core/beta/runnables/context.py
@@ -121,7 +121,7 @@ def _config_with_context(
     return patch_config(config, configurable=context_funcs)


-def aconfig_with_context(
+async def aconfig_with_context(
    config: RunnableConfig,
    steps: list[Runnable],
 ) -> RunnableConfig:
@@ -134,7 +134,9 @@ def aconfig_with_context(
     Returns:
         The patched runnable config.
     """
-    return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event)
+    return await asyncio.to_thread(
+        _config_with_context, config, steps, _asetter, _agetter, asyncio.Event
+    )


 def config_with_context(
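
The change above moves the synchronous _config_with_context helper onto a worker thread so the event loop is never blocked. A self-contained sketch of the same asyncio.to_thread pattern (Python 3.9+), with illustrative names:

import asyncio
import time


def blocking_work(x: int) -> int:
    # Stands in for a synchronous helper that may block.
    time.sleep(0.1)
    return x * 2


async def main() -> None:
    # The call runs in a thread; the event loop stays responsive.
    result = await asyncio.to_thread(blocking_work, 21)
    print(result)  # 42


asyncio.run(main())
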
6 changes: 3 additions & 3 deletions libs/core/langchain_core/runnables/base.py
@@ -3037,7 +3037,7 @@ async def ainvoke(
         from langchain_core.beta.runnables.context import aconfig_with_context

         # setup callbacks and context
-        config = aconfig_with_context(ensure_config(config), self.steps)
+        config = await aconfig_with_context(ensure_config(config), self.steps)
         callback_manager = get_async_callback_manager_for_config(config)
         # start the root run
         run_manager = await callback_manager.on_chain_start(
@@ -3214,7 +3214,7 @@ async def abatch(

         # setup callbacks and context
         configs = [
-            aconfig_with_context(c, self.steps)
+            await aconfig_with_context(c, self.steps)
             for c in get_config_list(config, len(inputs))
         ]
         callback_managers = [
@@ -3364,7 +3364,7 @@ async def _atransform(
         from langchain_core.beta.runnables.context import aconfig_with_context

         steps = [self.first] + self.middle + [self.last]
-        config = aconfig_with_context(config, self.steps)
+        config = await aconfig_with_context(config, self.steps)

         # stream the last steps
         # transform the input stream of each step with the next
39 changes: 37 additions & 2 deletions libs/core/poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions libs/core/pyproject.toml
@@ -105,6 +105,8 @@ pytest-asyncio = "^0.21.1"
 grandalf = "^0.8"
 responses = "^0.25.0"
 pytest-socket = "^0.7.0"
+blockbuster = "~1.5.8"
+aiofiles = "^24.1.0"
 [[tool.poetry.group.test.dependencies.numpy]]
 version = "^1.24.0"
 python = "<3.12"
17 changes: 17 additions & 0 deletions libs/core/tests/unit_tests/conftest.py
@@ -5,10 +5,27 @@
 from uuid import UUID

 import pytest
+from blockbuster import blockbuster_ctx
 from pytest import Config, Function, Parser
 from pytest_mock import MockerFixture


+@pytest.fixture(autouse=True)
+def blockbuster(request):
+    with blockbuster_ctx() as bb:
+        for func in ["os.stat", "os.path.abspath"]:
+            bb.functions[func].can_block_in(
+                "langchain_core/_api/internal.py", "is_caller_internal"
+            )
+
+        for func in ["os.stat", "io.TextIOWrapper.read"]:
+            bb.functions[func].can_block_in(
+                "langsmith/client.py", "_default_retry_config"
+            )
+
+        yield bb
+
+
 def pytest_addoption(parser: Parser) -> None:
     """Add custom command line options to pytest."""
     parser.addoption(
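
The can_block_in() calls in this fixture allow-list specific call sites where a nominally blocking function is considered acceptable, rather than disabling detection globally. A hypothetical example of the same pattern — the module path and function name here are placeholders, not real langchain-core code:

from blockbuster import blockbuster_ctx

with blockbuster_ctx() as bb:
    # Permit os.stat only when called from _warm_cache in mypkg/cache.py;
    # any other call site still fails under a running event loop.
    bb.functions["os.stat"].can_block_in("mypkg/cache.py", "_warm_cache")
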
7 changes: 6 additions & 1 deletion libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -191,7 +191,12 @@ async def on_llm_new_token(
     model = GenericFakeChatModel(messages=infinite_cycle)
     tokens: list[str] = []
     # New model
-    results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
+    results = [
+        chunk
+        async for chunk in model.astream(
+            "meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}
+        )
+    ]
     assert results == [
         _any_id_ai_message_chunk(content="hello"),
         _any_id_ai_message_chunk(content=" "),
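
The test now collects chunks with an async list comprehension over astream(), keeping the assertion on the asynchronous code path instead of the blocking sync stream(). A tiny self-contained sketch of that idiom, with an illustrative generator:

import asyncio


async def fake_stream():
    for token in ("hello", " ", "goodbye"):
        yield token


async def main() -> None:
    # Async list comprehension: consume the stream without blocking.
    chunks = [c async for c in fake_stream()]
    assert chunks == ["hello", " ", "goodbye"]


asyncio.run(main())
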
59 changes: 37 additions & 22 deletions libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -18,6 +18,7 @@
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
+from langchain_core.tracers import LogStreamCallbackHandler
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.context import collect_runs
 from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
@@ -304,39 +305,48 @@ def _stream(


 @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
-async def test_disable_streaming(
+def test_disable_streaming(
     disable_streaming: Union[bool, Literal["tool_calling"]],
 ) -> None:
     model = StreamingModel(disable_streaming=disable_streaming)
     assert model.invoke([]).content == "invoke"
-    assert (await model.ainvoke([])).content == "invoke"

     expected = "invoke" if disable_streaming is True else "stream"
     assert next(model.stream([])).content == expected
-    async for c in model.astream([]):
-        assert c.content == expected
-        break
+    assert (
+        model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
+        == expected
+    )

     expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
     assert next(model.stream([], tools=[{"type": "function"}])).content == expected
     assert (
         model.invoke(
-            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+            [], config={"callbacks": [LogStreamCallbackHandler()]}, tools=[{}]
         ).content
         == expected
     )


+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_async(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = StreamingModel(disable_streaming=disable_streaming)
+    assert (await model.ainvoke([])).content == "invoke"
+
+    expected = "invoke" if disable_streaming is True else "stream"
+    async for c in model.astream([]):
+        assert c.content == expected
+        break
     assert (
         await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
     ).content == expected

     expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
-    assert next(model.stream([], tools=[{"type": "function"}])).content == expected
-    assert (
-        model.invoke(
-            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
-        ).content
-        == expected
-    )
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == expected
+        break
     assert (
         await model.ainvoke(
             [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
@@ -345,26 +355,31 @@


 @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
-async def test_disable_streaming_no_streaming_model(
+def test_disable_streaming_no_streaming_model(
     disable_streaming: Union[bool, Literal["tool_calling"]],
 ) -> None:
     model = NoStreamingModel(disable_streaming=disable_streaming)
     assert model.invoke([]).content == "invoke"
-    assert (await model.ainvoke([])).content == "invoke"
     assert next(model.stream([])).content == "invoke"
-    async for c in model.astream([]):
-        assert c.content == "invoke"
-        break
     assert (
-        model.invoke(
-            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
-        ).content
+        model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
         == "invoke"
     )
+    assert next(model.stream([], tools=[{}])).content == "invoke"
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_no_streaming_model_async(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = NoStreamingModel(disable_streaming=disable_streaming)
+    assert (await model.ainvoke([])).content == "invoke"
+    async for c in model.astream([]):
+        assert c.content == "invoke"
+        break
     assert (
         await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
     ).content == "invoke"
-    assert next(model.stream([], tools=[{}])).content == "invoke"
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == "invoke"
+        break
10 changes: 6 additions & 4 deletions libs/core/tests/unit_tests/prompts/test_chat.py
@@ -1,9 +1,9 @@
 import base64
-import tempfile
 import warnings
 from pathlib import Path
 from typing import Any, Union, cast

+import aiofiles
 import pytest
 from pydantic import ValidationError
 from syrupy import SnapshotAssertion
@@ -724,9 +724,11 @@ async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
     in_mem = "base64mem"
     in_file_data = "base64file01"

-    with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file:
-        temp_file.write(base64.b64decode(in_file_data))
-        temp_file.flush()
+    async with aiofiles.tempfile.NamedTemporaryFile(
+        delete=True, suffix=".jpg"
+    ) as temp_file:
+        await temp_file.write(base64.b64decode(in_file_data))
+        await temp_file.flush()

         template = ChatPromptTemplate.from_messages(
             [
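
aiofiles exposes file operations behind an async interface and runs them in a worker thread, so the temp-file writes above no longer block the loop. A minimal sketch of the API used, assuming the aiofiles.tempfile support available in recent aiofiles releases:

import asyncio

import aiofiles


async def main() -> None:
    # Each await hands the blocking file operation to a worker thread.
    async with aiofiles.tempfile.NamedTemporaryFile(suffix=".jpg") as f:
        await f.write(b"\x00\x01")
        await f.flush()


asyncio.run(main())
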
17 changes: 12 additions & 5 deletions libs/core/tests/unit_tests/runnables/test_context.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any, Callable, NamedTuple, Union

 import pytest
@@ -330,19 +331,26 @@ def seq_naive_rag_scoped() -> Runnable:


 @pytest.mark.parametrize("runnable, cases", test_cases)
-async def test_context_runnables(
+def test_context_runnables(
     runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
 ) -> None:
     runnable = runnable if isinstance(runnable, Runnable) else runnable()
     assert runnable.invoke(cases[0].input) == cases[0].output
-    assert await runnable.ainvoke(cases[1].input) == cases[1].output
     assert runnable.batch([case.input for case in cases]) == [
         case.output for case in cases
     ]
+    assert add(runnable.stream(cases[0].input)) == cases[0].output
+
+
+@pytest.mark.parametrize("runnable, cases", test_cases)
+async def test_context_runnables_async(
+    runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
+) -> None:
+    runnable = runnable if isinstance(runnable, Runnable) else runnable()
+    assert await runnable.ainvoke(cases[1].input) == cases[1].output
     assert await runnable.abatch([case.input for case in cases]) == [
         case.output for case in cases
     ]
-    assert add(runnable.stream(cases[0].input)) == cases[0].output
     assert await aadd(runnable.astream(cases[1].input)) == cases[1].output
@@ -390,8 +398,7 @@ async def test_runnable_seq_streaming_chunks() -> None:
             "prompt": Context.getter("prompt"),
         }
     )
-
-    chunks = list(chain.stream({"foo": "foo", "bar": "bar"}))
+    chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"}))
     achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
     for c in chunks:
         assert c in achunks
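
Iterating a synchronous generator blocks, so the test above hands the entire list() call to a worker thread with asyncio.to_thread while the async variant is consumed natively. A small sketch of that idiom, with an illustrative generator:

import asyncio
import time


def slow_numbers():
    # A synchronous generator whose iteration blocks.
    for i in range(3):
        time.sleep(0.05)
        yield i


async def main() -> None:
    # list(slow_numbers()) would block the loop; run it in a thread.
    items = await asyncio.to_thread(list, slow_numbers())
    assert items == [0, 1, 2]


asyncio.run(main())
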
14 changes: 12 additions & 2 deletions libs/core/tests/unit_tests/runnables/test_fallbacks.py
@@ -82,17 +82,27 @@ def chain_pass_exceptions() -> Runnable:
     "runnable",
     ["llm", "llm_multi", "chain", "chain_pass_exceptions"],
 )
-async def test_fallbacks(
+def test_fallbacks(
     runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
 ) -> None:
     runnable = request.getfixturevalue(runnable)
     assert runnable.invoke("hello") == "bar"
     assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
     assert list(runnable.stream("hello")) == ["bar"]
+    assert dumps(runnable, pretty=True) == snapshot
+
+
+@pytest.mark.parametrize(
+    "runnable",
+    ["llm", "llm_multi", "chain", "chain_pass_exceptions"],
+)
+async def test_fallbacks_async(
+    runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
+) -> None:
+    runnable = request.getfixturevalue(runnable)
     assert await runnable.ainvoke("hello") == "bar"
     assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
     assert list(await runnable.ainvoke("hello")) == list("bar")
     assert dumps(runnable, pretty=True) == snapshot


 def _runnable(inputs: dict) -> str:
(The remaining changed files are not rendered here.)