Skip to content

Commit

Permalink
core: fix .bind when used with RunnableLambda async methods (#17739)
Browse files Browse the repository at this point in the history
**Description:** Here is a minimal example to illustrate behavior:
```python
from langchain_core.runnables import RunnableLambda

def my_function(*args, **kwargs):
    return 3 + kwargs.get("n", 0)

runnable = RunnableLambda(my_function).bind(n=1)


assert 4 == runnable.invoke({})
assert [4] == list(runnable.stream({}))

assert 4 == await runnable.ainvoke({})
assert [4] == [item async for item in runnable.astream({})]
```
Here, `runnable.invoke({})` and `runnable.stream({})` work fine, but
`runnable.ainvoke({})` raises
```
TypeError: RunnableLambda._ainvoke.<locals>.func() got an unexpected keyword argument 'n'
```
and similarly for `runnable.astream({})`:
```
TypeError: RunnableLambda._atransform.<locals>.func() got an unexpected keyword argument 'n'
```
Here we assume that this behavior is undesired and attempt to fix it.

**Issue:** #17241,
#16446
  • Loading branch information
ccurme authored Feb 21, 2024
1 parent f541545 commit 1b0802b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3414,6 +3414,7 @@ def func(
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
output: Optional[Output] = None
for chunk in call_func_with_variable_args(
Expand All @@ -3438,6 +3439,7 @@ def func(
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
return call_func_with_variable_args(
self.func, input, config, run_manager.get_sync(), **kwargs
Expand Down Expand Up @@ -3643,6 +3645,7 @@ def func(
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
return call_func_with_variable_args(
self.func, input, config, run_manager.get_sync(), **kwargs
Expand Down
20 changes: 20 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3424,6 +3424,26 @@ def test_bind_bind() -> None:
) == dumpd(llm.bind(stop=["Observation:"], one="two", hello="world"))


def test_bind_with_lambda() -> None:
def my_function(*args: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0)

runnable = RunnableLambda(my_function).bind(n=1)
assert 4 == runnable.invoke({})
chunks = list(runnable.stream({}))
assert [4] == chunks


async def test_bind_with_lambda_async() -> None:
def my_function(*args: Any, **kwargs: Any) -> int:
return 3 + kwargs.get("n", 0)

runnable = RunnableLambda(my_function).bind(n=1)
assert 4 == await runnable.ainvoke({})
chunks = [item async for item in runnable.astream({})]
assert [4] == chunks


def test_deep_stream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
Expand Down

0 comments on commit 1b0802b

Please sign in to comment.