diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py
index a53e8fdf57969..398ad488b8942 100644
--- a/libs/core/langchain_core/beta/runnables/context.py
+++ b/libs/core/langchain_core/beta/runnables/context.py
@@ -121,7 +121,7 @@ def _config_with_context(
     return patch_config(config, configurable=context_funcs)


-def aconfig_with_context(
+async def aconfig_with_context(
     config: RunnableConfig,
     steps: list[Runnable],
 ) -> RunnableConfig:
@@ -134,7 +134,9 @@ def aconfig_with_context(
     Returns:
         The patched runnable config.
     """
-    return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event)
+    return await asyncio.to_thread(
+        _config_with_context, config, steps, _asetter, _agetter, asyncio.Event
+    )


 def config_with_context(
diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index 893f393d8b174..eabc0b8ed2c37 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -3037,7 +3037,7 @@ async def ainvoke(
         from langchain_core.beta.runnables.context import aconfig_with_context

         # setup callbacks and context
-        config = aconfig_with_context(ensure_config(config), self.steps)
+        config = await aconfig_with_context(ensure_config(config), self.steps)
         callback_manager = get_async_callback_manager_for_config(config)
         # start the root run
         run_manager = await callback_manager.on_chain_start(
@@ -3214,7 +3214,7 @@ async def abatch(

         # setup callbacks and context
         configs = [
-            aconfig_with_context(c, self.steps)
+            await aconfig_with_context(c, self.steps)
             for c in get_config_list(config, len(inputs))
         ]
         callback_managers = [
@@ -3364,7 +3364,7 @@ async def _atransform(
         from langchain_core.beta.runnables.context import aconfig_with_context

         steps = [self.first] + self.middle + [self.last]
-        config = aconfig_with_context(config, self.steps)
+        config = await aconfig_with_context(config, self.steps)

         # stream the last steps
         # transform the input stream of each step with the next
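The change above turns aconfig_with_context into a coroutine: the synchronous _config_with_context helper still does the work, but it now runs on a worker thread via asyncio.to_thread so the event loop is never blocked while the config is patched, and every caller in runnables/base.py awaits it. A minimal, self-contained sketch of that pattern follows; the helper name and config shape are illustrative stand-ins, not code from this diff.

    import asyncio

    def _patch_config_sync(config: dict) -> dict:
        # Stand-in for a blocking helper such as _config_with_context.
        return {**config, "configurable": {"patched": True}}

    async def apatch_config(config: dict) -> dict:
        # asyncio.to_thread submits the blocking call to the default executor
        # and awaits the result without stalling the event loop.
        return await asyncio.to_thread(_patch_config_sync, config)

    print(asyncio.run(apatch_config({"tags": []})))
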
diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock
index 53962198d8297..5bcc88c6e05e4 100644
--- a/libs/core/poetry.lock
+++ b/libs/core/poetry.lock
@@ -1,4 +1,15 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+
+[[package]]
+name = "aiofiles"
+version = "24.1.0"
+description = "File support for asyncio."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"},
+    {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"},
+]

 [[package]]
 name = "annotated-types"
@@ -220,6 +231,20 @@ webencodings = "*"
 [package.extras]
 css = ["tinycss2 (>=1.1.0,<1.5)"]

+[[package]]
+name = "blockbuster"
+version = "1.5.8"
+description = "Utility to detect blocking calls in the async event loop"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "blockbuster-1.5.8-py3-none-any.whl", hash = "sha256:ea0352823acbd837872785a96ec87701f1c7938410d66c6a8857ae11646d9744"},
+    {file = "blockbuster-1.5.8.tar.gz", hash = "sha256:0018080fd735e84f0b138f15ff0a339f99444ecdc0045f16056c7b23db92706b"},
+]
+
+[package.dependencies]
+forbiddenfruit = ">=0.1.4"
+
 [[package]]
 name = "certifi"
 version = "2024.12.14"
@@ -551,6 +576,16 @@ files = [
 [package.extras]
 devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"]

+[[package]]
+name = "forbiddenfruit"
+version = "0.1.4"
+description = "Patch python built-in objects"
+optional = false
+python-versions = "*"
+files = [
+    {file = "forbiddenfruit-0.1.4.tar.gz", hash = "sha256:e3f7e66561a29ae129aac139a85d610dbf3dd896128187ed5454b6421f624253"},
+]
+
 [[package]]
 name = "fqdn"
 version = "1.5.1"
@@ -3104,4 +3139,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4.0"
-content-hash = "5accfdfd412486fbf7bb3ef18f00e75db40599034428651ef014b0bc3927ddfa"
+content-hash = "052d2f4a12a0abec317943c7492d882699f06f15548b3524d011d58d1ed94e92"
diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml
index 2bba5c7ff609d..d48185ce13f6f 100644
--- a/libs/core/pyproject.toml
+++ b/libs/core/pyproject.toml
@@ -105,6 +105,8 @@ pytest-asyncio = "^0.21.1"
 grandalf = "^0.8"
 responses = "^0.25.0"
 pytest-socket = "^0.7.0"
+blockbuster = "~1.5.8"
+aiofiles = "^24.1.0"
 [[tool.poetry.group.test.dependencies.numpy]]
 version = "^1.24.0"
 python = "<3.12"
diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py
index 29819a8066958..050aa4b8e175b 100644
--- a/libs/core/tests/unit_tests/conftest.py
+++ b/libs/core/tests/unit_tests/conftest.py
@@ -5,10 +5,27 @@
 from uuid import UUID

 import pytest
+from blockbuster import blockbuster_ctx
 from pytest import Config, Function, Parser
 from pytest_mock import MockerFixture


+@pytest.fixture(autouse=True)
+def blockbuster(request):
+    with blockbuster_ctx() as bb:
+        for func in ["os.stat", "os.path.abspath"]:
+            bb.functions[func].can_block_in(
+                "langchain_core/_api/internal.py", "is_caller_internal"
+            )
+
+        for func in ["os.stat", "io.TextIOWrapper.read"]:
+            bb.functions[func].can_block_in(
+                "langsmith/client.py", "_default_retry_config"
+            )
+
+        yield bb
+
+
 def pytest_addoption(parser: Parser) -> None:
     """Add custom command line options to pytest."""
     parser.addoption(
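The autouse fixture above wraps every unit test in blockbuster_ctx(), so a blocking call made while the event loop is running fails the test instead of silently stalling it; can_block_in() allowlists the few known-safe call sites in langchain_core and langsmith. Below is a short sketch of how such a fixture surfaces a blocking call; it assumes pytest-asyncio (already a test dependency) and that blockbuster raises its blocking-call error for time.sleep, so treat the exception handling as an approximation rather than the exact API.

    import asyncio
    import time

    import pytest
    from blockbuster import blockbuster_ctx

    @pytest.fixture(autouse=True)
    def blockbuster():
        with blockbuster_ctx() as bb:
            yield bb

    async def test_blocking_call_is_flagged() -> None:
        await asyncio.sleep(0)
        # With the fixture active, a blocking call inside the running
        # loop raises instead of freezing the loop.
        with pytest.raises(Exception):
            time.sleep(0.01)
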
diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
index 7502e17c50fde..5d7d4525e3ffe 100644
--- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
+++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -191,7 +191,12 @@ async def on_llm_new_token(
     model = GenericFakeChatModel(messages=infinite_cycle)
     tokens: list[str] = []
     # New model
-    results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
+    results = [
+        chunk
+        async for chunk in model.astream(
+            "meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}
+        )
+    ]
     assert results == [
         _any_id_ai_message_chunk(content="hello"),
         _any_id_ai_message_chunk(content=" "),
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index 2cd08a27d0383..16305eaa21fc8 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -18,6 +18,7 @@
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
+from langchain_core.tracers import LogStreamCallbackHandler
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.context import collect_runs
 from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
@@ -304,39 +305,48 @@ def _stream(


 @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
-async def test_disable_streaming(
+def test_disable_streaming(
     disable_streaming: Union[bool, Literal["tool_calling"]],
 ) -> None:
     model = StreamingModel(disable_streaming=disable_streaming)
     assert model.invoke([]).content == "invoke"
-    assert (await model.ainvoke([])).content == "invoke"

     expected = "invoke" if disable_streaming is True else "stream"
     assert next(model.stream([])).content == expected
-    async for c in model.astream([]):
-        assert c.content == expected
-        break
+    assert (
+        model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
+        == expected
+    )
+
+    expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
+    assert next(model.stream([], tools=[{"type": "function"}])).content == expected
     assert (
         model.invoke(
-            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+            [], config={"callbacks": [LogStreamCallbackHandler()]}, tools=[{}]
         ).content
         == expected
     )
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_async(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = StreamingModel(disable_streaming=disable_streaming)
+    assert (await model.ainvoke([])).content == "invoke"
+
+    expected = "invoke" if disable_streaming is True else "stream"
+    async for c in model.astream([]):
+        assert c.content == expected
+        break
     assert (
         await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
     ).content == expected

     expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
-    assert next(model.stream([], tools=[{"type": "function"}])).content == expected
     async for c in model.astream([], tools=[{}]):
         assert c.content == expected
         break
-    assert (
-        model.invoke(
-            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
-        ).content
-        == expected
-    )
     assert (
         await model.ainvoke(
             [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
@@ -345,26 +355,31 @@ async def test_disable_streaming(


 @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
-async def test_disable_streaming_no_streaming_model(
+def test_disable_streaming_no_streaming_model(
     disable_streaming: Union[bool, Literal["tool_calling"]],
 ) -> None:
     model = NoStreamingModel(disable_streaming=disable_streaming)
     assert model.invoke([]).content == "invoke"
-    assert (await model.ainvoke([])).content == "invoke"
     assert next(model.stream([])).content == "invoke"
-    async for c in model.astream([]):
-        assert c.content == "invoke"
-        break
     assert (
-        model.invoke(
-            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
-        ).content
+        model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
         == "invoke"
     )
+    assert next(model.stream([], tools=[{}])).content == "invoke"
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_no_streaming_model_async(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = NoStreamingModel(disable_streaming=disable_streaming)
+    assert (await model.ainvoke([])).content == "invoke"
+    async for c in model.astream([]):
+        assert c.content == "invoke"
+        break
     assert (
         await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
     ).content == "invoke"
-    assert next(model.stream([], tools=[{}])).content == "invoke"
     async for c in model.astream([], tools=[{}]):
         assert c.content == "invoke"
         break
diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py
index 6249aa6f47893..0d76f06bd0d05 100644
--- a/libs/core/tests/unit_tests/prompts/test_chat.py
+++ b/libs/core/tests/unit_tests/prompts/test_chat.py
@@ -1,9 +1,9 @@
 import base64
-import tempfile
 import warnings
 from pathlib import Path
 from typing import Any, Union, cast

+import aiofiles
 import pytest
 from pydantic import ValidationError
 from syrupy import SnapshotAssertion
@@ -724,9 +724,11 @@ async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
     in_mem = "base64mem"
     in_file_data = "base64file01"

-    with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file:
-        temp_file.write(base64.b64decode(in_file_data))
-        temp_file.flush()
+    async with aiofiles.tempfile.NamedTemporaryFile(
+        delete=True, suffix=".jpg"
+    ) as temp_file:
+        await temp_file.write(base64.b64decode(in_file_data))
+        await temp_file.flush()

         template = ChatPromptTemplate.from_messages(
             [
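test_chat.py now writes its temporary fixture file with aiofiles instead of tempfile, so the file I/O happens off the event loop and each write/flush is awaited. A small sketch of that usage, independent of the test above (the suffix and payload are arbitrary placeholders):

    import asyncio

    import aiofiles

    async def write_temp_bytes(data: bytes) -> None:
        # aiofiles.tempfile mirrors the stdlib tempfile API but runs the
        # underlying file operations in a thread pool, so calls are awaited.
        async with aiofiles.tempfile.NamedTemporaryFile(suffix=".jpg") as f:
            await f.write(data)
            await f.flush()

    asyncio.run(write_temp_bytes(b"example-bytes"))
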
diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py
index c00eb999424cb..cb8de2dd8088b 100644
--- a/libs/core/tests/unit_tests/runnables/test_context.py
+++ b/libs/core/tests/unit_tests/runnables/test_context.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any, Callable, NamedTuple, Union

 import pytest
@@ -330,19 +331,26 @@ def seq_naive_rag_scoped() -> Runnable:


 @pytest.mark.parametrize("runnable, cases", test_cases)
-async def test_context_runnables(
+def test_context_runnables(
     runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
 ) -> None:
     runnable = runnable if isinstance(runnable, Runnable) else runnable()
     assert runnable.invoke(cases[0].input) == cases[0].output
-    assert await runnable.ainvoke(cases[1].input) == cases[1].output
     assert runnable.batch([case.input for case in cases]) == [
         case.output for case in cases
     ]
+    assert add(runnable.stream(cases[0].input)) == cases[0].output
+
+
+@pytest.mark.parametrize("runnable, cases", test_cases)
+async def test_context_runnables_async(
+    runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
+) -> None:
+    runnable = runnable if isinstance(runnable, Runnable) else runnable()
+    assert await runnable.ainvoke(cases[1].input) == cases[1].output
     assert await runnable.abatch([case.input for case in cases]) == [
         case.output for case in cases
     ]
-    assert add(runnable.stream(cases[0].input)) == cases[0].output
     assert await aadd(runnable.astream(cases[1].input)) == cases[1].output


@@ -390,8 +398,7 @@ async def test_runnable_seq_streaming_chunks() -> None:
             "prompt": Context.getter("prompt"),
         }
     )
-
-    chunks = list(chain.stream({"foo": "foo", "bar": "bar"}))
+    chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"}))
     achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
     for c in chunks:
         assert c in achunks
diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py
index 731b3ddaa62aa..13f500d17dcbc 100644
--- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py
+++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py
@@ -82,17 +82,27 @@ def chain_pass_exceptions() -> Runnable:
     "runnable",
     ["llm", "llm_multi", "chain", "chain_pass_exceptions"],
 )
-async def test_fallbacks(
+def test_fallbacks(
     runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
 ) -> None:
     runnable = request.getfixturevalue(runnable)
     assert runnable.invoke("hello") == "bar"
     assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
     assert list(runnable.stream("hello")) == ["bar"]
+    assert dumps(runnable, pretty=True) == snapshot
+
+
+@pytest.mark.parametrize(
+    "runnable",
+    ["llm", "llm_multi", "chain", "chain_pass_exceptions"],
+)
+async def test_fallbacks_async(
+    runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
+) -> None:
+    runnable = request.getfixturevalue(runnable)
     assert await runnable.ainvoke("hello") == "bar"
     assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
     assert list(await runnable.ainvoke("hello")) == list("bar")
-    assert dumps(runnable, pretty=True) == snapshot


 def _runnable(inputs: dict) -> str:
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py
index b06cb381e80e6..123fa3d0fc184 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable.py
@@ -1,3 +1,4 @@
+import asyncio
 import sys
 import uuid
 import warnings
@@ -1011,20 +1012,20 @@ def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None:
     )


-async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
+def test_passthrough_tap(mocker: MockerFixture) -> None:
     fake = FakeRunnable()
     mock = mocker.Mock()

     seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)

-    assert await seq.ainvoke("hello", my_kwarg="value") == 5
+    assert seq.invoke("hello", my_kwarg="value") == 5  # type: ignore[call-arg]
     assert mock.call_args_list == [
         mocker.call("hello", my_kwarg="value"),
         mocker.call(5),
     ]
     mock.reset_mock()

-    assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]
+    assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]
     assert len(mock.call_args_list) == 4
     for call in [
         mocker.call("hello", my_kwarg="value"),
@@ -1035,9 +1036,7 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
         assert call in mock.call_args_list
     mock.reset_mock()

-    assert await seq.abatch(
-        ["hello", "byebye"], my_kwarg="value", return_exceptions=True
-    ) == [
+    assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [
         5,
         6,
     ]
@@ -1052,12 +1051,10 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
     mock.reset_mock()

     assert sorted(
-        [
-            a
-            async for a in seq.abatch_as_completed(
-                ["hello", "byebye"], my_kwarg="value", return_exceptions=True
-            )
-        ]
+        a
+        for a in seq.batch_as_completed(
+            ["hello", "byebye"], my_kwarg="value", return_exceptions=True
+        )
     ) == [
         (0, 5),
         (1, 6),
@@ -1072,26 +1069,30 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
         assert call in mock.call_args_list
     mock.reset_mock()

-    assert [
-        part
-        async for part in seq.astream(
-            "hello", {"metadata": {"key": "value"}}, my_kwarg="value"
-        )
-    ] == [5]
+    assert list(
+        seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value")
+    ) == [5]
     assert mock.call_args_list == [
         mocker.call("hello", my_kwarg="value"),
         mocker.call(5),
     ]
     mock.reset_mock()

-    assert seq.invoke("hello", my_kwarg="value") == 5  # type: ignore[call-arg]
+
+async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
+    fake = FakeRunnable()
+    mock = mocker.Mock()
+
+    seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
+
+    assert await seq.ainvoke("hello", my_kwarg="value") == 5
     assert mock.call_args_list == [
         mocker.call("hello", my_kwarg="value"),
         mocker.call(5),
     ]
     mock.reset_mock()

-    assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]
+    assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]
     assert len(mock.call_args_list) == 4
     for call in [
         mocker.call("hello", my_kwarg="value"),
@@ -1102,7 +1103,9 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
         assert call in mock.call_args_list
     mock.reset_mock()

-    assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [
+    assert await seq.abatch(
+        ["hello", "byebye"], my_kwarg="value", return_exceptions=True
+    ) == [
         5,
         6,
     ]
@@ -1117,10 +1120,12 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
     mock.reset_mock()

     assert sorted(
-        a
-        for a in seq.batch_as_completed(
-            ["hello", "byebye"], my_kwarg="value", return_exceptions=True
-        )
+        [
+            a
+            async for a in seq.abatch_as_completed(
+                ["hello", "byebye"], my_kwarg="value", return_exceptions=True
+            )
+        ]
     ) == [
         (0, 5),
         (1, 6),
@@ -1135,14 +1140,16 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
         assert call in mock.call_args_list
     mock.reset_mock()

-    assert list(
-        seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value")
-    ) == [5]
+    assert [
+        part
+        async for part in seq.astream(
+            "hello", {"metadata": {"key": "value"}}, my_kwarg="value"
+        )
+    ] == [5]
     assert mock.call_args_list == [
         mocker.call("hello", my_kwarg="value"),
         mocker.call(5),
     ]
-    mock.reset_mock()


 async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
@@ -1173,7 +1180,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
     spy.reset_mock()


-async def test_with_config(mocker: MockerFixture) -> None:
+def test_with_config(mocker: MockerFixture) -> None:
     fake = FakeRunnable()
     spy = mocker.spy(fake, "invoke")

@@ -1279,7 +1286,11 @@ async def test_with_config(mocker: MockerFixture) -> None:
     for i, call in enumerate(spy.call_args_list):
         assert call.args[0] == ("hello" if i == 0 else "wooorld")
         assert call.args[1].get("tags") == ["a-tag"]
-    spy.reset_mock()
+
+
+async def test_with_config_async(mocker: MockerFixture) -> None:
+    fake = FakeRunnable()
+    spy = mocker.spy(fake, "invoke")

     handler = ConsoleCallbackHandler()
     assert (
@@ -1375,7 +1386,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
     )


-async def test_default_method_implementations(mocker: MockerFixture) -> None:
+def test_default_method_implementations(mocker: MockerFixture) -> None:
     fake = FakeRunnable()
     spy = mocker.spy(fake, "invoke")

@@ -1416,7 +1427,11 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     for call in spy.call_args_list:
         assert call.args[1].get("tags") == ["a-tag"]
         assert call.args[1].get("metadata") == {}
-    spy.reset_mock()
+
+
+async def test_default_method_implementations_async(mocker: MockerFixture) -> None:
+    fake = FakeRunnable()
+    spy = mocker.spy(fake, "invoke")

     assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
     assert spy.call_args_list == [
@@ -1445,7 +1460,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     }


-async def test_prompt() -> None:
+def test_prompt() -> None:
     prompt = ChatPromptTemplate.from_messages(
         messages=[
             SystemMessage(content="You are a nice assistant."),
@@ -1478,6 +1493,21 @@ async def test_prompt() -> None:

     assert [*prompt.stream({"question": "What is your name?"})] == [expected]

+
+async def test_prompt_async() -> None:
+    prompt = ChatPromptTemplate.from_messages(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessagePromptTemplate.from_template("{question}"),
+        ]
+    )
+    expected = ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+    assert await prompt.ainvoke({"question": "What is your name?"}) == expected

     assert await prompt.abatch(
@@ -2773,9 +2803,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->


 @freeze_time("2023-01-01")
-async def test_router_runnable(
-    mocker: MockerFixture, snapshot: SnapshotAssertion
-) -> None:
+def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
     chain1: Runnable = ChatPromptTemplate.from_template(
         "You are a math genius. Answer the question: {question}"
     ) | FakeListLLM(responses=["4"])
@@ -2797,14 +2825,6 @@ async def test_router_runnable(
     )
     assert result2 == ["4", "2"]

-    result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
-    assert result == "4"
-
-    result2 = await chain.abatch(
-        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
-    )
-    assert result2 == ["4", "2"]
-
     # Test invoke
     router_spy = mocker.spy(router.__class__, "invoke")
     tracer = FakeTracer()
@@ -2824,8 +2844,30 @@ async def test_router_runnable(
     assert len(router_run.child_runs) == 2


+async def test_router_runnable_async() -> None:
+    chain1: Runnable = ChatPromptTemplate.from_template(
+        "You are a math genius. Answer the question: {question}"
+    ) | FakeListLLM(responses=["4"])
+    chain2: Runnable = ChatPromptTemplate.from_template(
+        "You are an english major. Answer the question: {question}"
+    ) | FakeListLLM(responses=["2"])
+    router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
+    chain: Runnable = {
+        "key": lambda x: x["key"],
+        "input": {"question": lambda x: x["question"]},
+    } | router
+
+    result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
+    assert result == "4"
+
+    result2 = await chain.abatch(
+        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
+    )
+    assert result2 == ["4", "2"]
+
+
 @freeze_time("2023-01-01")
-async def test_higher_order_lambda_runnable(
+def test_higher_order_lambda_runnable(
     mocker: MockerFixture, snapshot: SnapshotAssertion
 ) -> None:
     math_chain: Runnable = ChatPromptTemplate.from_template(
@@ -2859,14 +2901,6 @@ def router(input: dict[str, Any]) -> Runnable:
     )
     assert result2 == ["4", "2"]

-    result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
-    assert result == "4"
-
-    result2 = await chain.abatch(
-        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
-    )
-    assert result2 == ["4", "2"]
-
     # Test invoke
     math_spy = mocker.spy(math_chain.__class__, "invoke")
     tracer = FakeTracer()
@@ -2888,6 +2922,38 @@ def router(input: dict[str, Any]) -> Runnable:
     assert math_run.name == "RunnableSequence"
     assert len(math_run.child_runs) == 3

+
+async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None:
+    math_chain: Runnable = ChatPromptTemplate.from_template(
+        "You are a math genius. Answer the question: {question}"
+    ) | FakeListLLM(responses=["4"])
+    english_chain: Runnable = ChatPromptTemplate.from_template(
+        "You are an english major. Answer the question: {question}"
+    ) | FakeListLLM(responses=["2"])
+    input_map: Runnable = RunnableParallel(
+        key=lambda x: x["key"],
+        input={"question": lambda x: x["question"]},
+    )
+
+    def router(input: dict[str, Any]) -> Runnable:
+        if input["key"] == "math":
+            return itemgetter("input") | math_chain
+        elif input["key"] == "english":
+            return itemgetter("input") | english_chain
+        else:
+            msg = f"Unknown key: {input['key']}"
+            raise ValueError(msg)
+
+    chain: Runnable = input_map | router
+
+    result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
+    assert result == "4"
+
+    result2 = await chain.abatch(
+        [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
+    )
+    assert result2 == ["4", "2"]
+
     # Test ainvoke
     async def arouter(input: dict[str, Any]) -> Runnable:
         if input["key"] == "math":
@@ -4643,7 +4709,7 @@ async def test_tool_from_runnable() -> None:
     }


-async def test_runnable_gen() -> None:
+def test_runnable_gen() -> None:
     """Test that a generator can be used as a runnable."""

     def gen(input: Iterator[Any]) -> Iterator[int]:
@@ -4663,6 +4729,10 @@ def gen(input: Iterator[Any]) -> Iterator[int]:
     assert list(runnable.stream(None)) == [1, 2, 3]
     assert runnable.batch([None, None]) == [6, 6]

+
+async def test_runnable_gen_async() -> None:
+    """Test that a generator can be used as a runnable."""
+
     async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
         yield 1
         yield 2
@@ -4685,14 +4755,14 @@ async def __call__(self, input: AsyncIterator[Any]) -> AsyncIterator[int]:
     assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3]
     assert await arunnablecallable.abatch([None, None]) == [6, 6]
     with pytest.raises(NotImplementedError):
-        arunnablecallable.invoke(None)
+        await asyncio.to_thread(arunnablecallable.invoke, None)
     with pytest.raises(NotImplementedError):
-        arunnablecallable.stream(None)
+        await asyncio.to_thread(arunnablecallable.stream, None)
     with pytest.raises(NotImplementedError):
-        arunnablecallable.batch([None, None])
+        await asyncio.to_thread(arunnablecallable.batch, [None, None])


-async def test_runnable_gen_context_config() -> None:
+def test_runnable_gen_context_config() -> None:
     """Test that a generator can call other runnables with config
     propagated from the context."""

@@ -4761,9 +4831,16 @@ def gen(input: Iterator[Any]) -> Iterator[int]:
     assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
     assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]

-    if sys.version_info < (3, 11):
-        # Python 3.10 and below don't support running async tasks in a specific context
-        return
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 11),
+    reason="Python 3.10 and below don't support running async tasks in a specific context",
+)
+async def test_runnable_gen_context_config_async() -> None:
+    """Test that a generator can call other runnables with config
+    propagated from the context."""
+
+    fake = RunnableLambda(len)

     async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
         yield await fake.ainvoke("a")
@@ -4827,7 +4904,7 @@ async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
     assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]


-async def test_runnable_iter_context_config() -> None:
+def test_runnable_iter_context_config() -> None:
     """Test that a generator can call other runnables with config
     propagated from the context."""

@@ -4880,9 +4957,16 @@ def gen(input: str) -> Iterator[int]:
     assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
     assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]

-    if sys.version_info < (3, 11):
-        # Python 3.10 and below don't support running async tasks in a specific context
-        return
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 11),
+    reason="Python 3.10 and below don't support running async tasks in a specific context",
+)
+async def test_runnable_iter_context_config_async() -> None:
+    """Test that a generator can call other runnables with config
+    propagated from the context."""
+
+    fake = RunnableLambda(len)

     @chain
     async def agen(input: str) -> AsyncIterator[int]:
@@ -4944,7 +5028,7 @@ async def agen(input: str) -> AsyncIterator[int]:
     assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]


-async def test_runnable_lambda_context_config() -> None:
+def test_runnable_lambda_context_config() -> None:
     """Test that a function can call other runnables with config
     propagated from the context."""

@@ -4995,9 +5079,16 @@ def fun(input: str) -> int:
     assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"]
     assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3]

-    if sys.version_info < (3, 11):
-        # Python 3.10 and below don't support running async tasks in a specific context
-        return
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 11),
+    reason="Python 3.10 and below don't support running async tasks in a specific context",
+)
+async def test_runnable_lambda_context_config_async() -> None:
+    """Test that a function can call other runnables with config
+    propagated from the context."""
+
+    fake = RunnableLambda(len)

     @chain
     async def afun(input: str) -> int:
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py
index 7389d887769d8..89530594d828e 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py
@@ -1,5 +1,6 @@
 """Module that contains tests for runnable.astream_events API."""

+import asyncio
 import sys
 from collections.abc import AsyncIterator, Sequence
 from itertools import cycle
@@ -1948,9 +1949,12 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
         ]
     }

-    with_message_history.with_config(
-        {"configurable": {"session_id": "session-123"}}
-    ).invoke({"question": "meow"})
+    await asyncio.to_thread(
+        with_message_history.with_config(
+            {"configurable": {"session_id": "session-123"}}
+        ).invoke,
+        {"question": "meow"},
+    )
     assert store == {
         "session-123": [
             HumanMessage(content="hello"),
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
index 8ceb4bf38b5f1..a7071835786a7 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
@@ -38,7 +38,9 @@
     chain,
     ensure_config,
 )
-from langchain_core.runnables.config import get_callback_manager_for_config
+from langchain_core.runnables.config import (
+    get_async_callback_manager_for_config,
+)
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.runnables.schema import StreamEvent
 from langchain_core.runnables.utils import Input, Output
@@ -1914,9 +1916,12 @@ def _get_output_messages(*args, **kwargs):  # type: ignore
         ]
     }

-    with_message_history.with_config(
-        {"configurable": {"session_id": "session-123"}}
-    ).invoke({"question": "meow"})
+    await asyncio.to_thread(
+        with_message_history.with_config(
+            {"configurable": {"session_id": "session-123"}}
+        ).invoke,
+        {"question": "meow"},
+    )
     assert store == {
         "session-123": [
             HumanMessage(content="hello"),
@@ -1986,8 +1991,9 @@ def _get_output_messages(*args, **kwargs):  # type: ignore
     ]


-async def test_sync_in_async_stream_lambdas() -> None:
+async def test_sync_in_async_stream_lambdas(blockbuster) -> None:
     """Test invoking nested runnable lambda."""
+    blockbuster.deactivate()

     def add_one(x: int) -> int:
         return x + 1
@@ -2076,8 +2082,8 @@ async def astream(
         **kwargs: Optional[Any],
     ) -> AsyncIterator[Output]:
         config = ensure_config(config)
-        callback_manager = get_callback_manager_for_config(config)
-        run_manager = callback_manager.on_chain_start(
+        callback_manager = get_async_callback_manager_for_config(config)
+        run_manager = await callback_manager.on_chain_start(
             None,
             input,
             name=config.get("run_name", self.get_name()),
@@ -2100,9 +2106,9 @@ async def astream(
                 final_output = element

             # set final channel values as run output
-            run_manager.on_chain_end(final_output)
+            await run_manager.on_chain_end(final_output)
         except BaseException as e:
-            run_manager.on_chain_error(e)
+            await run_manager.on_chain_error(e)
             raise

diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py
index 0743929f86120..eb6ff8de81d91 100644
--- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py
+++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import sys
 import uuid
@@ -298,17 +299,17 @@ def parent(a: int) -> int:
     # Now run the chain and check the resulting posts
     cb = [tracer]
     if method == "invoke":
-        res: Any = parent.invoke(1, {"callbacks": cb})  # type: ignore
+        res: Any = await asyncio.to_thread(parent.invoke, 1, {"callbacks": cb})  # type: ignore
     elif method == "ainvoke":
         res = await parent.ainvoke(1, {"callbacks": cb})  # type: ignore
     elif method == "stream":
-        results = list(parent.stream(1, {"callbacks": cb}))  # type: ignore
+        results = await asyncio.to_thread(list, parent.stream(1, {"callbacks": cb}))  # type: ignore
         res = results[-1]
     elif method == "astream":
         results = [res async for res in parent.astream(1, {"callbacks": cb})]  # type: ignore
         res = results[-1]
     elif method == "batch":
-        res = parent.batch([1], {"callbacks": cb})[0]  # type: ignore
+        res = (await asyncio.to_thread(parent.batch, [1], {"callbacks": cb}))[0]  # type: ignore
     elif method == "abatch":
         res = (await parent.abatch([1], {"callbacks": cb}))[0]  # type: ignore
     else:
diff --git a/libs/core/tests/unit_tests/tracers/test_memory_stream.py b/libs/core/tests/unit_tests/tracers/test_memory_stream.py
index 451ab35bb678e..d49a619f12ff6 100644
--- a/libs/core/tests/unit_tests/tracers/test_memory_stream.py
+++ b/libs/core/tests/unit_tests/tracers/test_memory_stream.py
@@ -2,7 +2,6 @@
 import math
 import time
 from collections.abc import AsyncIterator
-from concurrent.futures import ThreadPoolExecutor

 from langchain_core.tracers.memory_stream import _MemoryStream

@@ -93,9 +92,9 @@ async def consumer() -> AsyncIterator[dict]:
             **item,
         }

-    with ThreadPoolExecutor() as executor:
-        executor.submit(sync_call)
-        items = [item async for item in consumer()]
+    task = asyncio.create_task(asyncio.to_thread(sync_call))
+    items = [item async for item in consumer()]
+    await task

     for item in items:
         delta_time = item["receive_time"] - item["produce_time"]
diff --git a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py
index e76f843616be8..bc030ae0a41c6 100644
--- a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py
+++ b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py
@@ -1,3 +1,4 @@
+import asyncio
 from pathlib import Path
 from unittest.mock import AsyncMock, Mock

@@ -90,9 +91,11 @@ async def test_inmemory_dump_load(tmp_path: Path) -> None:
     output = await store.asimilarity_search("foo", k=1)

     test_file = str(tmp_path / "test.json")
-    store.dump(test_file)
+    await asyncio.to_thread(store.dump, test_file)

-    loaded_store = InMemoryVectorStore.load(test_file, embedding)
+    loaded_store = await asyncio.to_thread(
+        InMemoryVectorStore.load, test_file, embedding
+    )
     loaded_output = await loaded_store.asimilarity_search("foo", k=1)

     assert output == loaded_output
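Across the test files above, synchronous Runnable entry points (invoke, batch, stream, dump/load) are exercised from async tests by handing them to asyncio.to_thread, and the ThreadPoolExecutor in test_memory_stream.py is replaced with a task wrapping asyncio.to_thread. A compact sketch of both shapes, with a placeholder function standing in for the blocking call:

    import asyncio

    def blocking_producer() -> list[int]:
        # Placeholder for a synchronous call such as runnable.invoke(...).
        return [1, 2, 3]

    async def main() -> None:
        # Await the blocking call directly on a worker thread...
        direct = await asyncio.to_thread(blocking_producer)
        # ...or schedule it as a task so async consumers can run alongside it.
        task = asyncio.create_task(asyncio.to_thread(blocking_producer))
        concurrent = await task
        assert direct == concurrent == [1, 2, 3]

    asyncio.run(main())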