From f1c2af60bc71ffefba79655908c60d0ae3ed806b Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 6 Jan 2025 16:10:15 +0100 Subject: [PATCH 1/4] Use Blockbuster to detect blocking calls in asyncio during tests --- .../langchain_core/beta/runnables/context.py | 6 +- libs/core/langchain_core/runnables/base.py | 6 +- libs/core/poetry.lock | 32 ++- libs/core/pyproject.toml | 1 + libs/core/tests/unit_tests/conftest.py | 17 ++ .../unit_tests/fake/test_fake_chat_model.py | 7 +- .../language_models/chat_models/test_base.py | 59 +++-- .../tests/unit_tests/prompts/test_chat.py | 67 +++-- .../unit_tests/runnables/test_context.py | 17 +- .../unit_tests/runnables/test_fallbacks.py | 14 +- .../unit_tests/runnables/test_runnable.py | 231 ++++++++++++------ .../runnables/test_runnable_events_v1.py | 10 +- .../runnables/test_runnable_events_v2.py | 24 +- .../runnables/test_tracing_interops.py | 7 +- .../unit_tests/tracers/test_memory_stream.py | 7 +- .../unit_tests/vectorstores/test_in_memory.py | 7 +- 16 files changed, 345 insertions(+), 167 deletions(-) diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index a53e8fdf57969..398ad488b8942 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -121,7 +121,7 @@ def _config_with_context( return patch_config(config, configurable=context_funcs) -def aconfig_with_context( +async def aconfig_with_context( config: RunnableConfig, steps: list[Runnable], ) -> RunnableConfig: @@ -134,7 +134,9 @@ def aconfig_with_context( Returns: The patched runnable config. """ - return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event) + return await asyncio.to_thread( + _config_with_context, config, steps, _asetter, _agetter, asyncio.Event + ) def config_with_context( diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 893f393d8b174..eabc0b8ed2c37 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -3037,7 +3037,7 @@ async def ainvoke( from langchain_core.beta.runnables.context import aconfig_with_context # setup callbacks and context - config = aconfig_with_context(ensure_config(config), self.steps) + config = await aconfig_with_context(ensure_config(config), self.steps) callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -3214,7 +3214,7 @@ async def abatch( # setup callbacks and context configs = [ - aconfig_with_context(c, self.steps) + await aconfig_with_context(c, self.steps) for c in get_config_list(config, len(inputs)) ] callback_managers = [ @@ -3364,7 +3364,7 @@ async def _atransform( from langchain_core.beta.runnables.context import aconfig_with_context steps = [self.first] + self.middle + [self.last] - config = aconfig_with_context(config, self.steps) + config = await aconfig_with_context(config, self.steps) # stream the last steps # transform the input stream of each step with the next diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index eadfb8d460c70..2bdc229af92b0 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -220,6 +220,20 @@ webencodings = "*" [package.extras] css = ["tinycss2 (>=1.1.0,<1.5)"] +[[package]] +name = "blockbuster" +version = "1.5.8" +description = "Utility to detect blocking calls in the async event loop" +optional = false +python-versions = ">=3.8" +files = [ + {file = "blockbuster-1.5.8-py3-none-any.whl", hash = "sha256:ea0352823acbd837872785a96ec87701f1c7938410d66c6a8857ae11646d9744"}, + {file = "blockbuster-1.5.8.tar.gz", hash = "sha256:0018080fd735e84f0b138f15ff0a339f99444ecdc0045f16056c7b23db92706b"}, +] + +[package.dependencies] +forbiddenfruit = ">=0.1.4" + [[package]] name = "certifi" version = "2024.12.14" @@ -565,6 +579,16 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "forbiddenfruit" +version = "0.1.4" +description = "Patch python built-in objects" +optional = false +python-versions = "*" +files = [ + {file = "forbiddenfruit-0.1.4.tar.gz", hash = "sha256:e3f7e66561a29ae129aac139a85d610dbf3dd896128187ed5454b6421f624253"}, +] + [[package]] name = "fqdn" version = "1.5.1" @@ -1225,7 +1249,7 @@ url = "../standard-tests" [[package]] name = "langchain-text-splitters" -version = "0.3.4" +version = "0.3.5" description = "LangChain text splitting utilities" optional = false python-versions = ">=3.9,<4.0" @@ -1233,7 +1257,7 @@ files = [] develop = true [package.dependencies] -langchain-core = "^0.3.26" +langchain-core = "^0.3.29" [package.source] type = "directory" @@ -3138,4 +3162,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "65d2f612fead6395befc285353347bf82d09044ce832c278f8b35e4f179caebb" +content-hash = "356040a2decb213bfe3a35a0f79cad3c9b10e7e941d43e309ab5e41e45b40093" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 65e40fafb09f8..a23fdc495e2d9 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -109,6 +109,7 @@ grandalf = "^0.8" responses = "^0.25.0" pytest-socket = "^0.7.0" pytest-xdist = "^3.6.1" +blockbuster = "~1.5.8" [[tool.poetry.group.test.dependencies.numpy]] version = "^1.24.0" python = "<3.12" diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 29819a8066958..050aa4b8e175b 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -5,10 +5,27 @@ from uuid import UUID import pytest +from blockbuster import blockbuster_ctx from pytest import Config, Function, Parser from pytest_mock import MockerFixture +@pytest.fixture(autouse=True) +def blockbuster(request): + with blockbuster_ctx() as bb: + for func in ["os.stat", "os.path.abspath"]: + bb.functions[func].can_block_in( + "langchain_core/_api/internal.py", "is_caller_internal" + ) + + for func in ["os.stat", "io.TextIOWrapper.read"]: + bb.functions[func].can_block_in( + "langsmith/client.py", "_default_retry_config" + ) + + yield bb + + def pytest_addoption(parser: Parser) -> None: """Add custom command line options to pytest.""" parser.addoption( diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 7502e17c50fde..5d7d4525e3ffe 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -191,7 +191,12 @@ async def on_llm_new_token( model = GenericFakeChatModel(messages=infinite_cycle) tokens: list[str] = [] # New model - results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) + results = [ + chunk + async for chunk in model.astream( + "meow", {"callbacks": [MyCustomAsyncHandler(tokens)]} + ) + ] assert results == [ _any_id_ai_message_chunk(content="hello"), _any_id_ai_message_chunk(content=" "), diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 2cd08a27d0383..16305eaa21fc8 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -18,6 +18,7 @@ ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs.llm_result import LLMResult +from langchain_core.tracers import LogStreamCallbackHandler from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.context import collect_runs from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler @@ -304,39 +305,48 @@ def _stream( @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) -async def test_disable_streaming( +def test_disable_streaming( disable_streaming: Union[bool, Literal["tool_calling"]], ) -> None: model = StreamingModel(disable_streaming=disable_streaming) assert model.invoke([]).content == "invoke" - assert (await model.ainvoke([])).content == "invoke" expected = "invoke" if disable_streaming is True else "stream" assert next(model.stream([])).content == expected - async for c in model.astream([]): - assert c.content == expected - break + assert ( + model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content + == expected + ) + + expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream" + assert next(model.stream([], tools=[{"type": "function"}])).content == expected assert ( model.invoke( - [], config={"callbacks": [_AstreamEventsCallbackHandler()]} + [], config={"callbacks": [LogStreamCallbackHandler()]}, tools=[{}] ).content == expected ) + + +@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) +async def test_disable_streaming_async( + disable_streaming: Union[bool, Literal["tool_calling"]], +) -> None: + model = StreamingModel(disable_streaming=disable_streaming) + assert (await model.ainvoke([])).content == "invoke" + + expected = "invoke" if disable_streaming is True else "stream" + async for c in model.astream([]): + assert c.content == expected + break assert ( await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]}) ).content == expected expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream" - assert next(model.stream([], tools=[{"type": "function"}])).content == expected async for c in model.astream([], tools=[{}]): assert c.content == expected break - assert ( - model.invoke( - [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}] - ).content - == expected - ) assert ( await model.ainvoke( [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}] @@ -345,26 +355,31 @@ async def test_disable_streaming( @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) -async def test_disable_streaming_no_streaming_model( +def test_disable_streaming_no_streaming_model( disable_streaming: Union[bool, Literal["tool_calling"]], ) -> None: model = NoStreamingModel(disable_streaming=disable_streaming) assert model.invoke([]).content == "invoke" - assert (await model.ainvoke([])).content == "invoke" assert next(model.stream([])).content == "invoke" - async for c in model.astream([]): - assert c.content == "invoke" - break assert ( - model.invoke( - [], config={"callbacks": [_AstreamEventsCallbackHandler()]} - ).content + model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content == "invoke" ) + assert next(model.stream([], tools=[{}])).content == "invoke" + + +@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) +async def test_disable_streaming_no_streaming_model_async( + disable_streaming: Union[bool, Literal["tool_calling"]], +) -> None: + model = NoStreamingModel(disable_streaming=disable_streaming) + assert (await model.ainvoke([])).content == "invoke" + async for c in model.astream([]): + assert c.content == "invoke" + break assert ( await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]}) ).content == "invoke" - assert next(model.stream([], tools=[{}])).content == "invoke" async for c in model.astream([], tools=[{}]): assert c.content == "invoke" break diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 6249aa6f47893..b61f4748952d8 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -1,5 +1,3 @@ -import base64 -import tempfile import warnings from pathlib import Path from typing import Any, Union, cast @@ -722,44 +720,39 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None: async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None: """Verify that we cannot pass `path` for an image as a variable.""" in_mem = "base64mem" - in_file_data = "base64file01" - with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file: - temp_file.write(base64.b64decode(in_file_data)) - temp_file.flush() - - template = ChatPromptTemplate.from_messages( - [ - ("system", "You are an AI assistant named {name}."), - ( - "human", - [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": "data:image/jpeg;base64,{in_mem}", - }, - { - "type": "image_url", - "image_url": {"path": "{file_path}"}, - }, - ], - ), - ] + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ( + "human", + [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "data:image/jpeg;base64,{in_mem}", + }, + { + "type": "image_url", + "image_url": {"path": "{file_path}"}, + }, + ], + ), + ] + ) + with pytest.raises(ValueError): + template.format_messages( + name="R2D2", + in_mem=in_mem, + file_path="some/path", ) - with pytest.raises(ValueError): - template.format_messages( - name="R2D2", - in_mem=in_mem, - file_path=temp_file.name, - ) - with pytest.raises(ValueError): - await template.aformat_messages( - name="R2D2", - in_mem=in_mem, - file_path=temp_file.name, - ) + with pytest.raises(ValueError): + await template.aformat_messages( + name="R2D2", + in_mem=in_mem, + file_path="some/path", + ) def test_messages_placeholder() -> None: diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py index c00eb999424cb..cb8de2dd8088b 100644 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Callable, NamedTuple, Union import pytest @@ -330,19 +331,26 @@ def seq_naive_rag_scoped() -> Runnable: @pytest.mark.parametrize("runnable, cases", test_cases) -async def test_context_runnables( +def test_context_runnables( runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] ) -> None: runnable = runnable if isinstance(runnable, Runnable) else runnable() assert runnable.invoke(cases[0].input) == cases[0].output - assert await runnable.ainvoke(cases[1].input) == cases[1].output assert runnable.batch([case.input for case in cases]) == [ case.output for case in cases ] + assert add(runnable.stream(cases[0].input)) == cases[0].output + + +@pytest.mark.parametrize("runnable, cases", test_cases) +async def test_context_runnables_async( + runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] +) -> None: + runnable = runnable if isinstance(runnable, Runnable) else runnable() + assert await runnable.ainvoke(cases[1].input) == cases[1].output assert await runnable.abatch([case.input for case in cases]) == [ case.output for case in cases ] - assert add(runnable.stream(cases[0].input)) == cases[0].output assert await aadd(runnable.astream(cases[1].input)) == cases[1].output @@ -390,8 +398,7 @@ async def test_runnable_seq_streaming_chunks() -> None: "prompt": Context.getter("prompt"), } ) - - chunks = list(chain.stream({"foo": "foo", "bar": "bar"})) + chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"})) achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})] for c in chunks: assert c in achunks diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 731b3ddaa62aa..13f500d17dcbc 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -82,17 +82,27 @@ def chain_pass_exceptions() -> Runnable: "runnable", ["llm", "llm_multi", "chain", "chain_pass_exceptions"], ) -async def test_fallbacks( +def test_fallbacks( runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion ) -> None: runnable = request.getfixturevalue(runnable) assert runnable.invoke("hello") == "bar" assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(runnable.stream("hello")) == ["bar"] + assert dumps(runnable, pretty=True) == snapshot + + +@pytest.mark.parametrize( + "runnable", + ["llm", "llm_multi", "chain", "chain_pass_exceptions"], +) +async def test_fallbacks_async( + runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion +) -> None: + runnable = request.getfixturevalue(runnable) assert await runnable.ainvoke("hello") == "bar" assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(await runnable.ainvoke("hello")) == list("bar") - assert dumps(runnable, pretty=True) == snapshot def _runnable(inputs: dict) -> str: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index b06cb381e80e6..123fa3d0fc184 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -1,3 +1,4 @@ +import asyncio import sys import uuid import warnings @@ -1011,20 +1012,20 @@ def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None: ) -async def test_passthrough_tap_async(mocker: MockerFixture) -> None: +def test_passthrough_tap(mocker: MockerFixture) -> None: fake = FakeRunnable() mock = mocker.Mock() seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock) - assert await seq.ainvoke("hello", my_kwarg="value") == 5 + assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg] assert mock.call_args_list == [ mocker.call("hello", my_kwarg="value"), mocker.call(5), ] mock.reset_mock() - assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6] + assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6] assert len(mock.call_args_list) == 4 for call in [ mocker.call("hello", my_kwarg="value"), @@ -1035,9 +1036,7 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: assert call in mock.call_args_list mock.reset_mock() - assert await seq.abatch( - ["hello", "byebye"], my_kwarg="value", return_exceptions=True - ) == [ + assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [ 5, 6, ] @@ -1052,12 +1051,10 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: mock.reset_mock() assert sorted( - [ - a - async for a in seq.abatch_as_completed( - ["hello", "byebye"], my_kwarg="value", return_exceptions=True - ) - ] + a + for a in seq.batch_as_completed( + ["hello", "byebye"], my_kwarg="value", return_exceptions=True + ) ) == [ (0, 5), (1, 6), @@ -1072,26 +1069,30 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: assert call in mock.call_args_list mock.reset_mock() - assert [ - part - async for part in seq.astream( - "hello", {"metadata": {"key": "value"}}, my_kwarg="value" - ) - ] == [5] + assert list( + seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value") + ) == [5] assert mock.call_args_list == [ mocker.call("hello", my_kwarg="value"), mocker.call(5), ] mock.reset_mock() - assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg] + +async def test_passthrough_tap_async(mocker: MockerFixture) -> None: + fake = FakeRunnable() + mock = mocker.Mock() + + seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock) + + assert await seq.ainvoke("hello", my_kwarg="value") == 5 assert mock.call_args_list == [ mocker.call("hello", my_kwarg="value"), mocker.call(5), ] mock.reset_mock() - assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6] + assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6] assert len(mock.call_args_list) == 4 for call in [ mocker.call("hello", my_kwarg="value"), @@ -1102,7 +1103,9 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: assert call in mock.call_args_list mock.reset_mock() - assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [ + assert await seq.abatch( + ["hello", "byebye"], my_kwarg="value", return_exceptions=True + ) == [ 5, 6, ] @@ -1117,10 +1120,12 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: mock.reset_mock() assert sorted( - a - for a in seq.batch_as_completed( - ["hello", "byebye"], my_kwarg="value", return_exceptions=True - ) + [ + a + async for a in seq.abatch_as_completed( + ["hello", "byebye"], my_kwarg="value", return_exceptions=True + ) + ] ) == [ (0, 5), (1, 6), @@ -1135,14 +1140,16 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: assert call in mock.call_args_list mock.reset_mock() - assert list( - seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value") - ) == [5] + assert [ + part + async for part in seq.astream( + "hello", {"metadata": {"key": "value"}}, my_kwarg="value" + ) + ] == [5] assert mock.call_args_list == [ mocker.call("hello", my_kwarg="value"), mocker.call(5), ] - mock.reset_mock() async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: @@ -1173,7 +1180,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: spy.reset_mock() -async def test_with_config(mocker: MockerFixture) -> None: +def test_with_config(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke") @@ -1279,7 +1286,11 @@ async def test_with_config(mocker: MockerFixture) -> None: for i, call in enumerate(spy.call_args_list): assert call.args[0] == ("hello" if i == 0 else "wooorld") assert call.args[1].get("tags") == ["a-tag"] - spy.reset_mock() + + +async def test_with_config_async(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") handler = ConsoleCallbackHandler() assert ( @@ -1375,7 +1386,7 @@ async def test_with_config(mocker: MockerFixture) -> None: ) -async def test_default_method_implementations(mocker: MockerFixture) -> None: +def test_default_method_implementations(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke") @@ -1416,7 +1427,11 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: for call in spy.call_args_list: assert call.args[1].get("tags") == ["a-tag"] assert call.args[1].get("metadata") == {} - spy.reset_mock() + + +async def test_default_method_implementations_async(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") assert await fake.ainvoke("hello", config={"callbacks": []}) == 5 assert spy.call_args_list == [ @@ -1445,7 +1460,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: } -async def test_prompt() -> None: +def test_prompt() -> None: prompt = ChatPromptTemplate.from_messages( messages=[ SystemMessage(content="You are a nice assistant."), @@ -1478,6 +1493,21 @@ async def test_prompt() -> None: assert [*prompt.stream({"question": "What is your name?"})] == [expected] + +async def test_prompt_async() -> None: + prompt = ChatPromptTemplate.from_messages( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessagePromptTemplate.from_template("{question}"), + ] + ) + expected = ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert await prompt.ainvoke({"question": "What is your name?"}) == expected assert await prompt.abatch( @@ -2773,9 +2803,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> @freeze_time("2023-01-01") -async def test_router_runnable( - mocker: MockerFixture, snapshot: SnapshotAssertion -) -> None: +def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: chain1: Runnable = ChatPromptTemplate.from_template( "You are a math genius. Answer the question: {question}" ) | FakeListLLM(responses=["4"]) @@ -2797,14 +2825,6 @@ async def test_router_runnable( ) assert result2 == ["4", "2"] - result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) - assert result == "4" - - result2 = await chain.abatch( - [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] - ) - assert result2 == ["4", "2"] - # Test invoke router_spy = mocker.spy(router.__class__, "invoke") tracer = FakeTracer() @@ -2824,8 +2844,30 @@ async def test_router_runnable( assert len(router_run.child_runs) == 2 +async def test_router_runnable_async() -> None: + chain1: Runnable = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + chain2: Runnable = ChatPromptTemplate.from_template( + "You are an english major. Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + router: Runnable = RouterRunnable({"math": chain1, "english": chain2}) + chain: Runnable = { + "key": lambda x: x["key"], + "input": {"question": lambda x: x["question"]}, + } | router + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + + @freeze_time("2023-01-01") -async def test_higher_order_lambda_runnable( +def test_higher_order_lambda_runnable( mocker: MockerFixture, snapshot: SnapshotAssertion ) -> None: math_chain: Runnable = ChatPromptTemplate.from_template( @@ -2859,14 +2901,6 @@ def router(input: dict[str, Any]) -> Runnable: ) assert result2 == ["4", "2"] - result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) - assert result == "4" - - result2 = await chain.abatch( - [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] - ) - assert result2 == ["4", "2"] - # Test invoke math_spy = mocker.spy(math_chain.__class__, "invoke") tracer = FakeTracer() @@ -2888,6 +2922,38 @@ def router(input: dict[str, Any]) -> Runnable: assert math_run.name == "RunnableSequence" assert len(math_run.child_runs) == 3 + +async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None: + math_chain: Runnable = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + english_chain: Runnable = ChatPromptTemplate.from_template( + "You are an english major. Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + input_map: Runnable = RunnableParallel( + key=lambda x: x["key"], + input={"question": lambda x: x["question"]}, + ) + + def router(input: dict[str, Any]) -> Runnable: + if input["key"] == "math": + return itemgetter("input") | math_chain + elif input["key"] == "english": + return itemgetter("input") | english_chain + else: + msg = f"Unknown key: {input['key']}" + raise ValueError(msg) + + chain: Runnable = input_map | router + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + # Test ainvoke async def arouter(input: dict[str, Any]) -> Runnable: if input["key"] == "math": @@ -4643,7 +4709,7 @@ async def test_tool_from_runnable() -> None: } -async def test_runnable_gen() -> None: +def test_runnable_gen() -> None: """Test that a generator can be used as a runnable.""" def gen(input: Iterator[Any]) -> Iterator[int]: @@ -4663,6 +4729,10 @@ def gen(input: Iterator[Any]) -> Iterator[int]: assert list(runnable.stream(None)) == [1, 2, 3] assert runnable.batch([None, None]) == [6, 6] + +async def test_runnable_gen_async() -> None: + """Test that a generator can be used as a runnable.""" + async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: yield 1 yield 2 @@ -4685,14 +4755,14 @@ async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]: assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3] assert await arunnablecallable.abatch([None, None]) == [6, 6] with pytest.raises(NotImplementedError): - arunnablecallable.invoke(None) + await asyncio.to_thread(arunnablecallable.invoke, None) with pytest.raises(NotImplementedError): - arunnablecallable.stream(None) + await asyncio.to_thread(arunnablecallable.stream, None) with pytest.raises(NotImplementedError): - arunnablecallable.batch([None, None]) + await asyncio.to_thread(arunnablecallable.batch, [None, None]) -async def test_runnable_gen_context_config() -> None: +def test_runnable_gen_context_config() -> None: """Test that a generator can call other runnables with config propagated from the context.""" @@ -4761,9 +4831,16 @@ def gen(input: Iterator[Any]) -> Iterator[int]: assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] - if sys.version_info < (3, 11): - # Python 3.10 and below don't support running async tasks in a specific context - return + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.10 and below don't support running async tasks in a specific context", +) +async def test_runnable_gen_context_config_async() -> None: + """Test that a generator can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: yield await fake.ainvoke("a") @@ -4827,7 +4904,7 @@ async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] -async def test_runnable_iter_context_config() -> None: +def test_runnable_iter_context_config() -> None: """Test that a generator can call other runnables with config propagated from the context.""" @@ -4880,9 +4957,16 @@ def gen(input: str) -> Iterator[int]: assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] - if sys.version_info < (3, 11): - # Python 3.10 and below don't support running async tasks in a specific context - return + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.10 and below don't support running async tasks in a specific context", +) +async def test_runnable_iter_context_config_async() -> None: + """Test that a generator can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) @chain async def agen(input: str) -> AsyncIterator[int]: @@ -4944,7 +5028,7 @@ async def agen(input: str) -> AsyncIterator[int]: assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] -async def test_runnable_lambda_context_config() -> None: +def test_runnable_lambda_context_config() -> None: """Test that a function can call other runnables with config propagated from the context.""" @@ -4995,9 +5079,16 @@ def fun(input: str) -> int: assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] - if sys.version_info < (3, 11): - # Python 3.10 and below don't support running async tasks in a specific context - return + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.10 and below don't support running async tasks in a specific context", +) +async def test_runnable_lambda_context_config_async() -> None: + """Test that a function can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) @chain async def afun(input: str) -> int: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 7389d887769d8..89530594d828e 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -1,5 +1,6 @@ """Module that contains tests for runnable.astream_events API.""" +import asyncio import sys from collections.abc import AsyncIterator, Sequence from itertools import cycle @@ -1948,9 +1949,12 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory: ] } - with_message_history.with_config( - {"configurable": {"session_id": "session-123"}} - ).invoke({"question": "meow"}) + await asyncio.to_thread( + with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).invoke, + {"question": "meow"}, + ) assert store == { "session-123": [ HumanMessage(content="hello"), diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 8ceb4bf38b5f1..a7071835786a7 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -38,7 +38,9 @@ chain, ensure_config, ) -from langchain_core.runnables.config import get_callback_manager_for_config +from langchain_core.runnables.config import ( + get_async_callback_manager_for_config, +) from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.utils import Input, Output @@ -1914,9 +1916,12 @@ def _get_output_messages(*args, **kwargs): # type: ignore ] } - with_message_history.with_config( - {"configurable": {"session_id": "session-123"}} - ).invoke({"question": "meow"}) + await asyncio.to_thread( + with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).invoke, + {"question": "meow"}, + ) assert store == { "session-123": [ HumanMessage(content="hello"), @@ -1986,8 +1991,9 @@ def _get_output_messages(*args, **kwargs): # type: ignore ] -async def test_sync_in_async_stream_lambdas() -> None: +async def test_sync_in_async_stream_lambdas(blockbuster) -> None: """Test invoking nested runnable lambda.""" + blockbuster.deactivate() def add_one(x: int) -> int: return x + 1 @@ -2076,8 +2082,8 @@ async def astream( **kwargs: Optional[Any], ) -> AsyncIterator[Output]: config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( + callback_manager = get_async_callback_manager_for_config(config) + run_manager = await callback_manager.on_chain_start( None, input, name=config.get("run_name", self.get_name()), @@ -2100,9 +2106,9 @@ async def astream( final_output = element # set final channel values as run output - run_manager.on_chain_end(final_output) + await run_manager.on_chain_end(final_output) except BaseException as e: - run_manager.on_chain_error(e) + await run_manager.on_chain_error(e) raise diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 0743929f86120..eb6ff8de81d91 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -1,3 +1,4 @@ +import asyncio import json import sys import uuid @@ -298,17 +299,17 @@ def parent(a: int) -> int: # Now run the chain and check the resulting posts cb = [tracer] if method == "invoke": - res: Any = parent.invoke(1, {"callbacks": cb}) # type: ignore + res: Any = await asyncio.to_thread(parent.invoke, 1, {"callbacks": cb}) # type: ignore elif method == "ainvoke": res = await parent.ainvoke(1, {"callbacks": cb}) # type: ignore elif method == "stream": - results = list(parent.stream(1, {"callbacks": cb})) # type: ignore + results = await asyncio.to_thread(list, parent.stream(1, {"callbacks": cb})) # type: ignore res = results[-1] elif method == "astream": results = [res async for res in parent.astream(1, {"callbacks": cb})] # type: ignore res = results[-1] elif method == "batch": - res = parent.batch([1], {"callbacks": cb})[0] # type: ignore + res = (await asyncio.to_thread(parent.batch, [1], {"callbacks": cb}))[0] # type: ignore elif method == "abatch": res = (await parent.abatch([1], {"callbacks": cb}))[0] # type: ignore else: diff --git a/libs/core/tests/unit_tests/tracers/test_memory_stream.py b/libs/core/tests/unit_tests/tracers/test_memory_stream.py index 451ab35bb678e..d49a619f12ff6 100644 --- a/libs/core/tests/unit_tests/tracers/test_memory_stream.py +++ b/libs/core/tests/unit_tests/tracers/test_memory_stream.py @@ -2,7 +2,6 @@ import math import time from collections.abc import AsyncIterator -from concurrent.futures import ThreadPoolExecutor from langchain_core.tracers.memory_stream import _MemoryStream @@ -93,9 +92,9 @@ async def consumer() -> AsyncIterator[dict]: **item, } - with ThreadPoolExecutor() as executor: - executor.submit(sync_call) - items = [item async for item in consumer()] + task = asyncio.create_task(asyncio.to_thread(sync_call)) + items = [item async for item in consumer()] + await task for item in items: delta_time = item["receive_time"] - item["produce_time"] diff --git a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py index e76f843616be8..bc030ae0a41c6 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py +++ b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py @@ -1,3 +1,4 @@ +import asyncio from pathlib import Path from unittest.mock import AsyncMock, Mock @@ -90,9 +91,11 @@ async def test_inmemory_dump_load(tmp_path: Path) -> None: output = await store.asimilarity_search("foo", k=1) test_file = str(tmp_path / "test.json") - store.dump(test_file) + await asyncio.to_thread(store.dump, test_file) - loaded_store = InMemoryVectorStore.load(test_file, embedding) + loaded_store = await asyncio.to_thread( + InMemoryVectorStore.load, test_file, embedding + ) loaded_output = await loaded_store.asimilarity_search("foo", k=1) assert output == loaded_output From a4d6433352275628e6982e98cc14fa75aab2130e Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 10 Jan 2025 13:46:41 +0100 Subject: [PATCH 2/4] Refactor test_runnable_sequence_parallel_trace_nesting --- libs/core/tests/unit_tests/conftest.py | 6 +- .../runnables/test_runnable_events_v2.py | 3 +- .../runnables/test_tracing_interops.py | 314 ++++++++++-------- 3 files changed, 179 insertions(+), 144 deletions(-) diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 050aa4b8e175b..c1b8d3c43f978 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -1,17 +1,17 @@ """Configuration for unit tests.""" -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from importlib import util from uuid import UUID import pytest -from blockbuster import blockbuster_ctx +from blockbuster import BlockBuster, blockbuster_ctx from pytest import Config, Function, Parser from pytest_mock import MockerFixture @pytest.fixture(autouse=True) -def blockbuster(request): +def blockbuster() -> Iterator[BlockBuster]: with blockbuster_ctx() as bb: for func in ["os.stat", "os.path.abspath"]: bb.functions[func].can_block_in( diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index a7071835786a7..97661b47403ee 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -13,6 +13,7 @@ ) import pytest +from blockbuster import BlockBuster from pydantic import BaseModel from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks @@ -1991,7 +1992,7 @@ def _get_output_messages(*args, **kwargs): # type: ignore ] -async def test_sync_in_async_stream_lambdas(blockbuster) -> None: +async def test_sync_in_async_stream_lambdas(blockbuster: BlockBuster) -> None: """Test invoking nested runnable lambda.""" blockbuster.deactivate() diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index eb6ff8de81d91..c4f1e79d794c7 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -1,9 +1,11 @@ -import asyncio +from __future__ import annotations + import json import sys import uuid -from collections.abc import AsyncGenerator, Generator -from typing import Any +from collections.abc import AsyncGenerator, Coroutine, Generator +from inspect import isasyncgenfunction +from typing import Any, Callable, Optional from unittest.mock import MagicMock, patch import pytest @@ -13,6 +15,7 @@ from langsmith.utils import get_env_var from typing_extensions import Literal +from langchain_core.callbacks import BaseCallbackHandler from langchain_core.runnables.base import RunnableLambda, RunnableParallel from langchain_core.tracers.langchain import LangChainTracer @@ -36,6 +39,17 @@ def _get_posts(client: Client) -> list: return posts +def _create_tracer_with_mocked_client( + project_name: Optional[str] = None, + tags: Optional[list[str]] = None, +) -> LangChainTracer: + mock_session = MagicMock() + mock_client_ = Client( + session=mock_session, api_key="test", auto_batch_tracing=False + ) + return LangChainTracer(client=mock_client_, project_name=project_name, tags=tags) + + def test_tracing_context() -> None: mock_session = MagicMock() mock_client_ = Client( @@ -57,12 +71,8 @@ def my_function(a: int) -> int: def test_config_traceable_handoff() -> None: get_env_var.cache_clear() - mock_session = MagicMock() - mock_client_ = Client( - session=mock_session, api_key="test", auto_batch_tracing=False - ) - tracer = LangChainTracer( - client=mock_client_, project_name="another-flippin-project", tags=["such-a-tag"] + tracer = _create_tracer_with_mocked_client( + project_name="another-flippin-project", tags=["such-a-tag"] ) @traceable @@ -101,7 +111,7 @@ def my_parent_function(a: int) -> int: my_parent_runnable = RunnableLambda(my_parent_function) assert my_parent_runnable.invoke(1, {"callbacks": [tracer]}) == 6 - posts = _get_posts(mock_client_) + posts = _get_posts(tracer.client) assert all(post["session_name"] == "another-flippin-project" for post in posts) # There should have been 6 runs created, # one for each function invocation @@ -144,11 +154,7 @@ def my_parent_function(a: int) -> int: sys.version_info < (3, 11), reason="Asyncio context vars require Python 3.11+" ) async def test_config_traceable_async_handoff() -> None: - mock_session = MagicMock() - mock_client_ = Client( - session=mock_session, api_key="test", auto_batch_tracing=False - ) - tracer = LangChainTracer(client=mock_client_) + tracer = _create_tracer_with_mocked_client() @traceable def my_great_great_grandchild_function(a: int) -> int: @@ -176,7 +182,7 @@ async def my_parent_function(a: int) -> int: my_parent_runnable = RunnableLambda(my_parent_function) # type: ignore result = await my_parent_runnable.ainvoke(1, {"callbacks": [tracer]}) assert result == 6 - posts = _get_posts(mock_client_) + posts = _get_posts(tracer.client) # There should have been 6 runs created, # one for each function invocation assert len(posts) == 6 @@ -246,144 +252,172 @@ def my_func(a: int) -> int: assert not mock_posts -@pytest.mark.parametrize( - "method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"] -) -async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None: - if method.startswith("a") and sys.version_info < (3, 11): - pytest.skip("Asyncio context vars require Python 3.11+") - mock_session = MagicMock() - mock_client_ = Client( - session=mock_session, api_key="test", auto_batch_tracing=False - ) - tracer = LangChainTracer(client=mock_client_) +class TestRunnableSequenceParallelTraceNesting: + @pytest.fixture(autouse=True) + def _setup(self) -> None: + self.tracer = _create_tracer_with_mocked_client() - @RunnableLambda - def my_child_function(a: int) -> int: - return a + 2 + @staticmethod + def _create_parent( + other_thing: Callable[ + [int], Generator[int, None, None] | AsyncGenerator[int, None] + ], + ) -> RunnableLambda: + @RunnableLambda + def my_child_function(a: int) -> int: + return a + 2 - if method.startswith("a"): + parallel = RunnableParallel( + chain_result=my_child_function.with_config(tags=["atag"]), + other_thing=other_thing, + ) - async def other_thing(a: int) -> AsyncGenerator[int, None]: - yield 1 + def before(x: int) -> int: + return x - else: + def after(x: dict) -> int: + return x["chain_result"] + + sequence = before | parallel | after + if isasyncgenfunction(other_thing): + + @RunnableLambda # type: ignore + async def parent(a: int) -> int: + return await sequence.ainvoke(a) + else: + + @RunnableLambda + def parent(a: int) -> int: + return sequence.invoke(a) + + return parent + + def _check_posts(self) -> None: + posts = _get_posts(self.tracer.client) + name_order = [ + "parent", + "RunnableSequence", + "before", + "RunnableParallel", + ["my_child_function", "other_thing"], + "after", + ] + expected_parents = { + "parent": None, + "RunnableSequence": "parent", + "before": "RunnableSequence", + "RunnableParallel": "RunnableSequence", + "my_child_function": "RunnableParallel", + "other_thing": "RunnableParallel", + "after": "RunnableSequence", + } + assert len(posts) == sum( + [1 if isinstance(n, str) else len(n) for n in name_order] + ) + prev_dotted_order = None + dotted_order_map = {} + id_map = {} + parent_id_map = {} + i = 0 + for name in name_order: + if isinstance(name, list): + for n in name: + matching_post = next( + p for p in posts[i : i + len(name)] if p["name"] == n + ) + assert matching_post + dotted_order = matching_post["dotted_order"] + if prev_dotted_order is not None: + assert dotted_order > prev_dotted_order + dotted_order_map[n] = dotted_order + id_map[n] = matching_post["id"] + parent_id_map[n] = matching_post.get("parent_run_id") + i += len(name) + continue + else: + assert posts[i]["name"] == name + dotted_order = posts[i]["dotted_order"] + if prev_dotted_order is not None and not str( + expected_parents[name] + ).startswith("RunnableParallel"): + assert ( + dotted_order > prev_dotted_order + ), f"{name} not after {name_order[i - 1]}" + prev_dotted_order = dotted_order + if name in dotted_order_map: + msg = f"Duplicate name {name}" + raise ValueError(msg) + dotted_order_map[name] = dotted_order + id_map[name] = posts[i]["id"] + parent_id_map[name] = posts[i].get("parent_run_id") + i += 1 + + # Now check the dotted orders + for name, parent_ in expected_parents.items(): + dotted_order = dotted_order_map[name] + if parent_ is not None: + parent_dotted_order = dotted_order_map[parent_] + assert dotted_order.startswith( + parent_dotted_order + ), f"{name}, {parent_dotted_order} not in {dotted_order}" + assert str(parent_id_map[name]) == str(id_map[parent_]) + else: + assert dotted_order.split(".")[0] == dotted_order + + @pytest.mark.parametrize( + "method", + [ + lambda parent, cb: parent.invoke(1, {"callbacks": cb}), + lambda parent, cb: list(parent.stream(1, {"callbacks": cb}))[-1], + lambda parent, cb: parent.batch([1], {"callbacks": cb})[0], + ], + ids=["invoke", "stream", "batch"], + ) + def test_sync( + self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int] + ) -> None: def other_thing(a: int) -> Generator[int, None, None]: # type: ignore yield 1 - parallel = RunnableParallel( - chain_result=my_child_function.with_config(tags=["atag"]), - other_thing=other_thing, - ) + parent = self._create_parent(other_thing) - def before(x: int) -> int: - return x + # Now run the chain and check the resulting posts + assert method(parent, [self.tracer]) == 3 - def after(x: dict) -> int: - return x["chain_result"] + self._check_posts() - sequence = before | parallel | after - if method.startswith("a"): + @staticmethod + async def ainvoke(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + return await parent.ainvoke(1, {"callbacks": cb}) - @RunnableLambda # type: ignore - async def parent(a: int) -> int: - return await sequence.ainvoke(a) + @staticmethod + async def astream(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + return [res async for res in parent.astream(1, {"callbacks": cb})][-1] - else: + @staticmethod + async def abatch(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + return (await parent.abatch([1], {"callbacks": cb}))[0] - @RunnableLambda - def parent(a: int) -> int: - return sequence.invoke(a) - - # Now run the chain and check the resulting posts - cb = [tracer] - if method == "invoke": - res: Any = await asyncio.to_thread(parent.invoke, 1, {"callbacks": cb}) # type: ignore - elif method == "ainvoke": - res = await parent.ainvoke(1, {"callbacks": cb}) # type: ignore - elif method == "stream": - results = await asyncio.to_thread(list, parent.stream(1, {"callbacks": cb})) # type: ignore - res = results[-1] - elif method == "astream": - results = [res async for res in parent.astream(1, {"callbacks": cb})] # type: ignore - res = results[-1] - elif method == "batch": - res = (await asyncio.to_thread(parent.batch, [1], {"callbacks": cb}))[0] # type: ignore - elif method == "abatch": - res = (await parent.abatch([1], {"callbacks": cb}))[0] # type: ignore - else: - msg = f"Unknown method {method}" - raise ValueError(msg) - assert res == 3 - posts = _get_posts(mock_client_) - name_order = [ - "parent", - "RunnableSequence", - "before", - "RunnableParallel", - ["my_child_function", "other_thing"], - "after", - ] - expected_parents = { - "parent": None, - "RunnableSequence": "parent", - "before": "RunnableSequence", - "RunnableParallel": "RunnableSequence", - "my_child_function": "RunnableParallel", - "other_thing": "RunnableParallel", - "after": "RunnableSequence", - } - assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order]) - prev_dotted_order = None - dotted_order_map = {} - id_map = {} - parent_id_map = {} - i = 0 - for name in name_order: - if isinstance(name, list): - for n in name: - matching_post = next( - p for p in posts[i : i + len(name)] if p["name"] == n - ) - assert matching_post - dotted_order = matching_post["dotted_order"] - if prev_dotted_order is not None: - assert dotted_order > prev_dotted_order - dotted_order_map[n] = dotted_order - id_map[n] = matching_post["id"] - parent_id_map[n] = matching_post.get("parent_run_id") - i += len(name) - continue - else: - assert posts[i]["name"] == name - dotted_order = posts[i]["dotted_order"] - if prev_dotted_order is not None and not str( - expected_parents[name] - ).startswith("RunnableParallel"): - assert ( - dotted_order > prev_dotted_order - ), f"{name} not after {name_order[i-1]}" - prev_dotted_order = dotted_order - if name in dotted_order_map: - msg = f"Duplicate name {name}" - raise ValueError(msg) - dotted_order_map[name] = dotted_order - id_map[name] = posts[i]["id"] - parent_id_map[name] = posts[i].get("parent_run_id") - i += 1 - - # Now check the dotted orders - for name, parent_ in expected_parents.items(): - dotted_order = dotted_order_map[name] - if parent_ is not None: - parent_dotted_order = dotted_order_map[parent_] - assert dotted_order.startswith( - parent_dotted_order - ), f"{name}, {parent_dotted_order} not in {dotted_order}" - assert str(parent_id_map[name]) == str(id_map[parent_]) - else: - assert dotted_order.split(".")[0] == dotted_order + @pytest.mark.skipif( + sys.version_info < (3, 11), reason="Asyncio context vars require Python 3.11+" + ) + @pytest.mark.parametrize("method", [ainvoke, astream, abatch]) + async def test_async( + self, + method: Callable[ + [RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int] + ], + ) -> None: + async def other_thing(a: int) -> AsyncGenerator[int, None]: + yield 1 + + parent = self._create_parent(other_thing) + + # Now run the chain and check the resulting posts + assert await method(parent, [self.tracer]) == 3 + + self._check_posts() @pytest.mark.parametrize("parent_type", ("ls", "lc")) From a57e4464a1196b95551b7b8dc8ab2fbcd6ecad84 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 10 Jan 2025 17:16:51 +0100 Subject: [PATCH 3/4] Fix aconfig_with_context asyncio.Event is not thread-safe so it must be created in the asyncio thread --- .../langchain_core/beta/runnables/context.py | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index 398ad488b8942..c6e26df2e3d8d 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -65,20 +65,11 @@ def _key_from_id(id_: str) -> str: def _config_with_context( config: RunnableConfig, - steps: list[Runnable], + context_specs: list[tuple[ConfigurableFieldSpec, int]], setter: Callable, getter: Callable, event_cls: Union[type[threading.Event], type[asyncio.Event]], ) -> RunnableConfig: - if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): - return config - - context_specs = [ - (spec, i) - for i, step in enumerate(steps) - for spec in step.config_specs - if spec.id.startswith(CONTEXT_CONFIG_PREFIX) - ] grouped_by_key = { key: list(group) for key, group in groupby( @@ -134,8 +125,17 @@ async def aconfig_with_context( Returns: The patched runnable config. """ - return await asyncio.to_thread( - _config_with_context, config, steps, _asetter, _agetter, asyncio.Event + if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): + return config + + context_specs = [ + (spec, i) + for i, step in enumerate(steps) + for spec in await asyncio.to_thread(getattr, step, "config_specs") + if spec.id.startswith(CONTEXT_CONFIG_PREFIX) + ] + return _config_with_context( + config, context_specs, _asetter, _agetter, asyncio.Event ) @@ -152,7 +152,18 @@ def config_with_context( Returns: The patched runnable config. """ - return _config_with_context(config, steps, _setter, _getter, threading.Event) + if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): + return config + + context_specs = [ + (spec, i) + for i, step in enumerate(steps) + for spec in step.config_specs + if spec.id.startswith(CONTEXT_CONFIG_PREFIX) + ] + return _config_with_context( + config, context_specs, _setter, _getter, threading.Event + ) @beta() From 82ce644a3d41cca4546422ea0b92f76bcd055a82 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 13 Jan 2025 16:22:00 +0100 Subject: [PATCH 4/4] Fix tests --- libs/core/poetry.lock | 8 ++++---- libs/core/pyproject.toml | 2 +- libs/core/tests/unit_tests/conftest.py | 11 +++++++++-- .../language_models/chat_models/test_rate_limiting.py | 9 +++++++++ .../core/tests/unit_tests/runnables/test_fallbacks.py | 4 +--- .../tests/unit_tests/tracers/test_memory_stream.py | 4 +++- 6 files changed, 27 insertions(+), 11 deletions(-) diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index 2bdc229af92b0..d9c6ca4a8ac9f 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -222,13 +222,13 @@ css = ["tinycss2 (>=1.1.0,<1.5)"] [[package]] name = "blockbuster" -version = "1.5.8" +version = "1.5.9" description = "Utility to detect blocking calls in the async event loop" optional = false python-versions = ">=3.8" files = [ - {file = "blockbuster-1.5.8-py3-none-any.whl", hash = "sha256:ea0352823acbd837872785a96ec87701f1c7938410d66c6a8857ae11646d9744"}, - {file = "blockbuster-1.5.8.tar.gz", hash = "sha256:0018080fd735e84f0b138f15ff0a339f99444ecdc0045f16056c7b23db92706b"}, + {file = "blockbuster-1.5.9-py3-none-any.whl", hash = "sha256:1a3c43f1682866a8a9464c0850341ca72ab91c6e6a5fad1e8e1ce58590a7c98a"}, + {file = "blockbuster-1.5.9.tar.gz", hash = "sha256:72d5696425cf86a6413043c5da73b8f436b7ea442606c47cd724329ef2fa4a91"}, ] [package.dependencies] @@ -3162,4 +3162,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "356040a2decb213bfe3a35a0f79cad3c9b10e7e941d43e309ab5e41e45b40093" +content-hash = "4ade2f7e47a95220acad02edcc7a7c7237d4b726e32632c7a30a2ceb8d369d4d" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index a23fdc495e2d9..ffd1ab4574a71 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -109,7 +109,7 @@ grandalf = "^0.8" responses = "^0.25.0" pytest-socket = "^0.7.0" pytest-xdist = "^3.6.1" -blockbuster = "~1.5.8" +blockbuster = "~1.5.9" [[tool.poetry.group.test.dependencies.numpy]] version = "^1.24.0" python = "<3.12" diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index c1b8d3c43f978..f17adfa0c1e20 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -14,8 +14,10 @@ def blockbuster() -> Iterator[BlockBuster]: with blockbuster_ctx() as bb: for func in ["os.stat", "os.path.abspath"]: - bb.functions[func].can_block_in( - "langchain_core/_api/internal.py", "is_caller_internal" + ( + bb.functions[func] + .can_block_in("langchain_core/_api/internal.py", "is_caller_internal") + .can_block_in("langchain_core/runnables/base.py", "__repr__") ) for func in ["os.stat", "io.TextIOWrapper.read"]: @@ -23,6 +25,11 @@ def blockbuster() -> Iterator[BlockBuster]: "langsmith/client.py", "_default_retry_config" ) + for bb_function in bb.functions.values(): + bb_function.can_block_in( + "freezegun/api.py", "_get_cached_module_attributes" + ) + yield bb diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py index 3547cc8af6b59..d0546736b288e 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py @@ -1,11 +1,20 @@ import time from typing import Optional as Optional +import pytest +from blockbuster import BlockBuster + from langchain_core.caches import InMemoryCache from langchain_core.language_models import GenericFakeChatModel from langchain_core.rate_limiters import InMemoryRateLimiter +@pytest.fixture(autouse=True) +def deactivate_blockbuster(blockbuster: BlockBuster) -> None: + # Deactivate BlockBuster to not disturb the rate limiter timings + blockbuster.deactivate() + + def test_rate_limit_invoke() -> None: """Add rate limiter.""" diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 13f500d17dcbc..6dd6b6eb226c6 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -96,9 +96,7 @@ def test_fallbacks( "runnable", ["llm", "llm_multi", "chain", "chain_pass_exceptions"], ) -async def test_fallbacks_async( - runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion -) -> None: +async def test_fallbacks_async(runnable: RunnableWithFallbacks, request: Any) -> None: runnable = request.getfixturevalue(runnable) assert await runnable.ainvoke("hello") == "bar" assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 diff --git a/libs/core/tests/unit_tests/tracers/test_memory_stream.py b/libs/core/tests/unit_tests/tracers/test_memory_stream.py index d49a619f12ff6..fd4a12e209efd 100644 --- a/libs/core/tests/unit_tests/tracers/test_memory_stream.py +++ b/libs/core/tests/unit_tests/tracers/test_memory_stream.py @@ -69,7 +69,7 @@ async def producer() -> None: """Produce items with slight delay.""" tic = time.time() for i in range(3): - await asyncio.sleep(0.10) + await asyncio.sleep(0.2) toc = time.time() await writer.send( { @@ -96,6 +96,8 @@ async def consumer() -> AsyncIterator[dict]: items = [item async for item in consumer()] await task + assert len(items) == 3 + for item in items: delta_time = item["receive_time"] - item["produce_time"] # Allow a generous 10ms of delay