diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 65da79c489b52..32321e7fc4275 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -14,7 +14,6 @@ This module contains schema and implementation of LangChain Runnables primitives. """ -from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.base import ( Runnable, RunnableBinding, @@ -40,9 +39,7 @@ "ConfigurableField", "ConfigurableFieldSingleOption", "ConfigurableFieldMultiOption", - "GetLocalVar", "patch_config", - "PutLocalVar", "RouterInput", "RouterRunnable", "Runnable", diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py deleted file mode 100644 index cfe1f76aa07ba..0000000000000 --- a/libs/langchain/langchain/schema/runnable/_locals.py +++ /dev/null @@ -1,168 +0,0 @@ -from __future__ import annotations - -from typing import ( - TYPE_CHECKING, - Any, - AsyncIterator, - Dict, - Iterator, - Mapping, - Optional, - Union, -) - -from langchain.schema.runnable.base import Input, Other, Output, RunnableSerializable -from langchain.schema.runnable.config import RunnableConfig -from langchain.schema.runnable.passthrough import RunnablePassthrough - -if TYPE_CHECKING: - from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, - ) - - -class PutLocalVar(RunnablePassthrough): - key: Union[str, Mapping[str, str]] - """The key(s) to use for storing the input variable(s) in local state. - - If a string is provided then the entire input is stored under that key. If a - Mapping is provided, then the map values are gotten from the input and - stored in local state under the map keys. - """ - - def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None: - super().__init__(key=key, **kwargs) - - def _concat_put( - self, - input: Other, - *, - config: Optional[RunnableConfig] = None, - replace: bool = False, - ) -> None: - if config is None: - raise ValueError( - "PutLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - if isinstance(self.key, str): - if self.key not in config["locals"] or replace: - config["locals"][self.key] = input - else: - config["locals"][self.key] += input - elif isinstance(self.key, Mapping): - if not isinstance(input, Mapping): - raise TypeError( - f"Received key of type Mapping but input of type {type(input)}. " - f"input is expected to be of type Mapping when key is Mapping." - ) - for input_key, put_key in self.key.items(): - if put_key not in config["locals"] or replace: - config["locals"][put_key] = input[input_key] - else: - config["locals"][put_key] += input[input_key] - else: - raise TypeError( - f"`key` should be a string or Mapping[str, str], received type " - f"{(type(self.key))}." - ) - - def invoke( - self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Other: - self._concat_put(input, config=config, replace=True) - return super().invoke(input, config=config, **kwargs) - - async def ainvoke( - self, - input: Other, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Other: - self._concat_put(input, config=config, replace=True) - return await super().ainvoke(input, config=config, **kwargs) - - def transform( - self, - input: Iterator[Other], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Other]: - for chunk in super().transform(input, config=config, **kwargs): - self._concat_put(chunk, config=config) - yield chunk - - async def atransform( - self, - input: AsyncIterator[Other], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Other]: - async for chunk in super().atransform(input, config=config, **kwargs): - self._concat_put(chunk, config=config) - yield chunk - - -class GetLocalVar( - RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]] -): - key: str - """The key to extract from the local state.""" - passthrough_key: Optional[str] = None - """The key to use for passing through the invocation input. - - If None, then only the value retrieved from local state is returned. Otherwise a - dictionary ``{self.key: <>, self.passthrough_key: <>}`` - is returned. - """ - - def __init__(self, key: str, **kwargs: Any) -> None: - super().__init__(key=key, **kwargs) - - def _get( - self, - input: Input, - run_manager: Union[CallbackManagerForChainRun, Any], - config: RunnableConfig, - ) -> Union[Output, Dict[str, Union[Input, Output]]]: - if self.passthrough_key: - return { - self.key: config["locals"][self.key], - self.passthrough_key: input, - } - else: - return config["locals"][self.key] - - async def _aget( - self, - input: Input, - run_manager: AsyncCallbackManagerForChainRun, - config: RunnableConfig, - ) -> Union[Output, Dict[str, Union[Input, Output]]]: - return self._get(input, run_manager, config) - - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None - ) -> Union[Output, Dict[str, Union[Input, Output]]]: - if config is None: - raise ValueError( - "GetLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - - return self._call_with_config(self._get, input, config) - - async def ainvoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Union[Output, Dict[str, Union[Input, Output]]]: - if config is None: - raise ValueError( - "GetLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - - return await self._acall_with_config(self._aget, input, config) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index b9a3fff16e312..a439e94550cc4 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1656,7 +1656,6 @@ def invoke( # mark each step as a child run patch_config( config, - copy_locals=True, callbacks=run_manager.get_child(f"map:key:{key}"), ), ) @@ -2534,10 +2533,7 @@ def batch( [merge_configs(self.config, conf) for conf in config], ) else: - configs = [ - patch_config(merge_configs(self.config, config), copy_locals=True) - for _ in range(len(inputs)) - ] + configs = [merge_configs(self.config, config) for _ in range(len(inputs))] return self.bound.batch( inputs, configs, @@ -2559,10 +2555,7 @@ async def abatch( [merge_configs(self.config, conf) for conf in config], ) else: - configs = [ - patch_config(merge_configs(self.config, config), copy_locals=True) - for _ in range(len(inputs)) - ] + configs = [merge_configs(self.config, config) for _ in range(len(inputs))] return await self.bound.abatch( inputs, configs, diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 77f68c26f721a..5faf2fcf1d5e9 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -64,13 +64,6 @@ class RunnableConfig(TypedDict, total=False): Name for the tracer run for this call. Defaults to the name of the class. """ - locals: Dict[str, Any] - """ - Variables scoped to this call and any sub-calls. Usually used with - GetLocalVar() and PutLocalVar(). Care should be taken when placing mutable - objects in locals, as they will be shared between parallel sub-calls. - """ - max_concurrency: Optional[int] """ Maximum number of parallel calls to make. If not provided, defaults to @@ -96,7 +89,6 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: tags=[], metadata={}, callbacks=None, - locals={}, recursion_limit=25, ) if config is not None: @@ -124,14 +116,13 @@ def get_config_list( return ( list(map(ensure_config, config)) if isinstance(config, list) - else [patch_config(config, copy_locals=True) for _ in range(length)] + else [ensure_config(config) for _ in range(length)] ) def patch_config( config: Optional[RunnableConfig], *, - copy_locals: bool = False, callbacks: Optional[BaseCallbackManager] = None, recursion_limit: Optional[int] = None, max_concurrency: Optional[int] = None, @@ -139,8 +130,6 @@ def patch_config( configurable: Optional[Dict[str, Any]] = None, ) -> RunnableConfig: config = ensure_config(config) - if copy_locals: - config["locals"] = config["locals"].copy() if callbacks is not None: # If we're replacing callbacks we need to unset run_name # As that should apply only to the same run as the original callbacks diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py deleted file mode 100644 index 82c055c0698ac..0000000000000 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Any, Callable, Type - -import pytest - -from langchain.llms import FakeListLLM -from langchain.prompts import PromptTemplate -from langchain.schema.runnable import ( - GetLocalVar, - PutLocalVar, - Runnable, - RunnablePassthrough, - RunnableSequence, -) - - -@pytest.mark.parametrize( - ("method", "input", "output"), - [ - (lambda r, x: r.invoke(x), "foo", "foo"), - (lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]), - (lambda r, x: list(r.stream(x))[0], "foo", "foo"), - ], -) -def test_put_get(method: Callable, input: Any, output: Any) -> None: - runnable = PutLocalVar("input") | GetLocalVar("input") - assert method(runnable, input) == output - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("method", "input", "output"), - [ - (lambda r, x: r.ainvoke(x), "foo", "foo"), - (lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]), - ], -) -async def test_put_get_async(method: Callable, input: Any, output: Any) -> None: - runnable = PutLocalVar("input") | GetLocalVar("input") - assert await method(runnable, input) == output - - -@pytest.mark.parametrize( - ("runnable", "error"), - [ - (PutLocalVar("input"), ValueError), - (GetLocalVar("input"), ValueError), - (PutLocalVar("input") | GetLocalVar("missing"), KeyError), - ], -) -def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None: - with pytest.raises(error): - runnable.invoke("foo") - - -def test_get_in_map() -> None: - runnable: Runnable = PutLocalVar("input") | {"bar": GetLocalVar("input")} - assert runnable.invoke("foo") == {"bar": "foo"} - - -def test_put_in_map() -> None: - runnable: Runnable = {"bar": PutLocalVar("input")} | GetLocalVar("input") - with pytest.raises(KeyError): - runnable.invoke("foo") - - -@pytest.mark.parametrize( - "runnable", - [ - PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"), - ( - PutLocalVar("input") - | {"input": RunnablePassthrough()} - | PromptTemplate.from_template("say {input}") - | FakeListLLM(responses=["hello"]) - | GetLocalVar("input", passthrough_key="output") - ), - ], -) -@pytest.mark.parametrize( - ("method", "input", "output"), - [ - (lambda r, x: r.invoke(x), "hello", {"input": "hello", "output": "hello"}), - (lambda r, x: r.batch(x), ["hello"], [{"input": "hello", "output": "hello"}]), - ( - lambda r, x: list(r.stream(x))[0], - "hello", - {"input": "hello", "output": "hello"}, - ), - ], -) -def test_put_get_sequence( - runnable: RunnableSequence, method: Callable, input: Any, output: Any -) -> None: - assert method(runnable, input) == output diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 15b166228b431..0bfc45e62775e 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1209,7 +1209,6 @@ async def test_with_config(mocker: MockerFixture) -> None: metadata={"key": "value"}, tags=["c"], callbacks=None, - locals={}, recursion_limit=5, ), ), @@ -1219,7 +1218,6 @@ async def test_with_config(mocker: MockerFixture) -> None: metadata={"key": "value"}, tags=["c"], callbacks=None, - locals={}, recursion_limit=5, ), ), @@ -1290,7 +1288,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: metadata={"key": "value"}, tags=[], callbacks=None, - locals={}, recursion_limit=25, ), ), @@ -1300,7 +1297,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: metadata={"key": "value"}, tags=[], callbacks=None, - locals={}, recursion_limit=25, ), ),