From 74bf620e978d62e70daa27b70f180884c2fb26bb Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 2 Oct 2024 12:50:58 -0400 Subject: [PATCH] core[patch]: Support injected tool args that are arbitrary types (#27045) This adds support for inject tool args that are arbitrary types when used with pydantic 2. We'll need to add similar logic on the v1 path, and potentially mirror the config from the original model when we're doing the subset. --- libs/core/langchain_core/utils/pydantic.py | 7 +++++-- libs/core/tests/unit_tests/test_tools.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 1352bcfafffb2..93375e09f348b 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -266,7 +266,7 @@ def _create_subset_model_v2( fn_description: Optional[str] = None, ) -> type[pydantic.BaseModel]: """Create a pydantic model with a subset of the model fields.""" - from pydantic import create_model + from pydantic import ConfigDict, create_model from pydantic.fields import FieldInfo descriptions_ = descriptions or {} @@ -278,7 +278,10 @@ def _create_subset_model_v2( if field.metadata: field_info.metadata = field.metadata fields[field_name] = (field.annotation, field_info) - rtn = create_model(name, **fields) # type: ignore + + rtn = create_model( # type: ignore + name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True) + ) # TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations. # This is done to preserve __annotations__ when working with pydantic 2.x diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index a61cead53c23f..3eae40ede1e82 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2090,3 +2090,18 @@ class FooSchema(BaseModel): with pytest.raises(NotImplementedError): assert tool.invoke("hello") == "hello" + + +def test_injected_arg_with_complex_type() -> None: + """Test that an injected tool arg can be a complex type.""" + + class Foo: + def __init__(self) -> None: + self.value = "bar" + + @tool + def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: + """Tool that has an injected tool arg.""" + return foo.value + + assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore