Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Feat: update config CVar in tool.invoke #20808

Merged
merged 5 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions libs/core/langchain_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

from __future__ import annotations

import asyncio
import inspect
import uuid
import warnings
from abc import ABC, abstractmethod
from contextvars import copy_context
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
Expand Down Expand Up @@ -60,7 +62,12 @@
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
from langchain_core.runnables.config import (
patch_config,
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.utils import accepts_context


class SchemaAnnotationError(TypeError):
Expand Down Expand Up @@ -255,6 +262,7 @@ def invoke(
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
config=config,
**kwargs,
)

Expand All @@ -272,6 +280,7 @@ async def ainvoke(
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
config=config,
**kwargs,
)

Expand Down Expand Up @@ -353,6 +362,7 @@ def run(
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
Expand Down Expand Up @@ -385,12 +395,20 @@ def run(
**kwargs,
)
try:
child_config = patch_config(
config,
callbacks=run_manager.get_child(),
)
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
parsed_input = self._parse_input(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
context.run(
self._run, *tool_args, run_manager=run_manager, **tool_kwargs
)
if new_arg_supported
else self._run(*tool_args, **tool_kwargs)
else context.run(self._run, *tool_args, **tool_kwargs)
)
except ValidationError as e:
if not self.handle_validation_error:
Expand Down Expand Up @@ -446,6 +464,7 @@ async def arun(
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
Expand Down Expand Up @@ -476,11 +495,24 @@ async def arun(
parsed_input = self._parse_input(tool_input)
# We then call the tool on the tool input to get an observation
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
child_config = patch_config(
config,
callbacks=run_manager.get_child(),
)
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
coro = (
context.run(
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs
)
if new_arg_supported
else await self._arun(*tool_args, **tool_kwargs)
else context.run(self._arun, *tool_args, **tool_kwargs)
)
if accepts_context(asyncio.create_task):
observation = await asyncio.create_task(coro, context=context) # type: ignore
else:
observation = await coro

except ValidationError as e:
if not self.handle_validation_error:
raise e
Expand Down
34 changes: 34 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Test the base tool implementation."""

import asyncio
import json
import sys
from datetime import datetime
from enum import Enum
from functools import partial
Expand All @@ -13,6 +15,7 @@
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import ensure_config
from langchain_core.tools import (
BaseTool,
SchemaAnnotationError,
Expand Down Expand Up @@ -871,3 +874,34 @@ def foo(bar: str, baz: Optional[int] = 3, buzz: Optional[str] = "buzz") -> dict:
else:
with pytest.raises(ValidationError):
foo.invoke(inputs) # type: ignore


def test_tool_pass_context() -> None:
@tool
def foo(bar: str) -> str:
"""The foo."""
config = ensure_config()
assert config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
return bar

assert foo.invoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="requires python3.11 or higher",
)
async def test_async_tool_pass_context() -> None:
@tool
async def foo(bar: str) -> str:
"""The foo."""
await asyncio.sleep(0.0001)
config = ensure_config()
assert config["configurable"]["foo"] == "not-bar"
assert bar == "baz"
return bar

assert (
await foo.ainvoke({"bar": "baz"}, {"configurable": {"foo": "not-bar"}}) == "baz" # type: ignore
)
Loading