Skip to content

Commit

Permalink
Add configurable + contextvars to tools
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Apr 24, 2024
1 parent 9111d3a commit ba840d9
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions libs/core/langchain_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
import inspect
import uuid
import warnings
<<<<<<< Updated upstream
from abc import ABC, abstractmethod
=======
from abc import abstractmethod
from contextvars import copy_context
>>>>>>> Stashed changes
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
Expand Down Expand Up @@ -60,7 +65,11 @@
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
from langchain_core.runnables.config import (
patch_config,
run_in_executor,
var_child_runnable_config,
)


class SchemaAnnotationError(TypeError):
Expand Down Expand Up @@ -255,6 +264,7 @@ def invoke(
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
configurable=config.get("configurable"),
**kwargs,
)

Expand All @@ -272,6 +282,7 @@ async def ainvoke(
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
configurable=config.get("configurable"),
**kwargs,
)

Expand Down Expand Up @@ -353,6 +364,7 @@ def run(
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
configurable: Optional[dict] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
Expand Down Expand Up @@ -385,9 +397,14 @@ def run(
**kwargs,
)
try:
child_config = patch_config(
{"configurable": configurable or {}}, callbacks=run_manager.get_child()
)
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
parsed_input = self._parse_input(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
observation = context.run(
self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
if new_arg_supported
else self._run(*tool_args, **tool_kwargs)
Expand Down Expand Up @@ -446,6 +463,7 @@ async def arun(
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
configurable: Optional[dict] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
Expand Down Expand Up @@ -476,7 +494,12 @@ async def arun(
parsed_input = self._parse_input(tool_input)
# We then call the tool on the tool input to get an observation
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
child_config = patch_config(
{"configurable": configurable or {}}, callbacks=run_manager.get_child()
)
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
observation = context.run(
await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
if new_arg_supported
else await self._arun(*tool_args, **tool_kwargs)
Expand Down

0 comments on commit ba840d9

Please sign in to comment.