From f64ecc9b778d5ef8d83466c234d621af935a56ff Mon Sep 17 00:00:00 2001 From: Zheyuan Wei Date: Sat, 23 Nov 2024 17:22:55 -0500 Subject: [PATCH 1/8] fix: filter out `InjectedToolArg` filter out `InjectedToolArg` when creating schemas from a function --- libs/core/langchain_core/tools/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 9782234dfb1a0..baf335ad9d321 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -279,6 +279,11 @@ def create_schema_from_function( ): filter_args_.append(existing_param) + # Filter out injected arguments + for param_name, param in sig.parameters.items(): + if param.annotation is InjectedToolArg: + filter_args_.append(param_name) + description, arg_descriptions = _infer_arg_descriptions( func, parse_docstring=parse_docstring, From 5ecc630e2435ffd8088e4f0d32f8c27f33543502 Mon Sep 17 00:00:00 2001 From: XiaoConan Date: Sun, 24 Nov 2024 22:24:48 -0500 Subject: [PATCH 2/8] Update base.py to identify InjectedToolArg annotation --- libs/core/langchain_core/tools/base.py | 33 ++++++++++++++++++-------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index baf335ad9d321..ef91c4e997cf9 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -214,7 +214,7 @@ def create_schema_from_function( filter_args: Optional[Sequence[str]] = None, parse_docstring: bool = False, error_on_invalid_docstring: bool = False, - include_injected: bool = True, + include_injected: bool = False, ) -> type[BaseModel]: """Create a pydantic schema from a function's signature. @@ -265,6 +265,13 @@ def create_schema_from_function( if filter_args: filter_args_ = filter_args + + existing_params: list[str] = list(sig.parameters.keys()) + for existing_param in existing_params: + if not include_injected and _is_injected_arg_type( + sig.parameters[existing_param].annotation + ): + filter_args_.append(existing_param) else: # Handle classmethods and instance methods existing_params: list[str] = list(sig.parameters.keys()) @@ -279,10 +286,10 @@ def create_schema_from_function( ): filter_args_.append(existing_param) - # Filter out injected arguments - for param_name, param in sig.parameters.items(): - if param.annotation is InjectedToolArg: - filter_args_.append(param_name) + # # Filter out injected arguments + # for param_name, param in sig.parameters.items(): + # if param.annotation is InjectedToolArg: + # filter_args_.append(param_name) description, arg_descriptions = _infer_arg_descriptions( func, @@ -960,11 +967,17 @@ class InjectedToolArg: def _is_injected_arg_type(type_: type) -> bool: - return any( - isinstance(arg, InjectedToolArg) - or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)) - for arg in get_args(type_)[1:] - ) + if type_ is InjectedToolArg: + return True + for arg in get_args(type_): + if arg is InjectedToolArg or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)): + return True + return False + # return any( + # isinstance(arg, InjectedToolArg) + # or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)) + # for arg in get_args(type_)[1:] + # ) def get_all_basemodel_annotations( From 3455d2a252c247b91ae173aba7ecda4131afcfd5 Mon Sep 17 00:00:00 2001 From: XiaoConan Date: Thu, 28 Nov 2024 05:53:52 -0500 Subject: [PATCH 3/8] Create Custom Json schema for InjectedToolArg annotation --- libs/core/langchain_core/tools/__init__.py | 1 + libs/core/langchain_core/tools/base.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/tools/__init__.py b/libs/core/langchain_core/tools/__init__.py index 3553ff10ec79f..822d5376a22bf 100644 --- a/libs/core/langchain_core/tools/__init__.py +++ b/libs/core/langchain_core/tools/__init__.py @@ -30,6 +30,7 @@ ) from langchain_core.tools.base import ( InjectedToolArg as InjectedToolArg, + InjectedToolArgSchema as InjectedToolArgSchema, ) from langchain_core.tools.base import SchemaAnnotationError as SchemaAnnotationError from langchain_core.tools.base import ( diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index ef91c4e997cf9..ecb0a093e81e7 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -33,6 +33,7 @@ ValidationError, model_validator, validate_arguments, + WithJsonSchema, ) from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 @@ -214,7 +215,7 @@ def create_schema_from_function( filter_args: Optional[Sequence[str]] = None, parse_docstring: bool = False, error_on_invalid_docstring: bool = False, - include_injected: bool = False, + include_injected: bool = True, ) -> type[BaseModel]: """Create a pydantic schema from a function's signature. @@ -965,6 +966,14 @@ def _get_runnable_config_param(func: Callable) -> Optional[str]: class InjectedToolArg: """Annotation for a Tool arg that is **not** meant to be generated by a model.""" +InjectedToolArgSchema = Annotated[ + InjectedToolArg, + WithJsonSchema( + { + "type": "Injected-Tool-Argument", + } + ), +] def _is_injected_arg_type(type_: type) -> bool: if type_ is InjectedToolArg: From a002cc9c76e7958563ba60ed658d696597218b90 Mon Sep 17 00:00:00 2001 From: XiaoConan Date: Fri, 29 Nov 2024 03:34:51 -0500 Subject: [PATCH 4/8] Added documentations to the modifed code for fixing InjectedToolArg schema inspection. --- libs/core/langchain_core/tools/base.py | 12 ++---------- libs/core/langchain_core/tools/structured.py | 2 ++ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index ecb0a093e81e7..07dd814e1d59c 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -269,6 +269,7 @@ def create_schema_from_function( existing_params: list[str] = list(sig.parameters.keys()) for existing_param in existing_params: + # add arguments with InjectedToolArg annotation to filter_args_ if not include_injected and _is_injected_arg_type( sig.parameters[existing_param].annotation ): @@ -287,11 +288,6 @@ def create_schema_from_function( ): filter_args_.append(existing_param) - # # Filter out injected arguments - # for param_name, param in sig.parameters.items(): - # if param.annotation is InjectedToolArg: - # filter_args_.append(param_name) - description, arg_descriptions = _infer_arg_descriptions( func, parse_docstring=parse_docstring, @@ -966,6 +962,7 @@ def _get_runnable_config_param(func: Callable) -> Optional[str]: class InjectedToolArg: """Annotation for a Tool arg that is **not** meant to be generated by a model.""" +# Add Custom json schema for injected tool arguments InjectedToolArgSchema = Annotated[ InjectedToolArg, WithJsonSchema( @@ -982,11 +979,6 @@ def _is_injected_arg_type(type_: type) -> bool: if arg is InjectedToolArg or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)): return True return False - # return any( - # isinstance(arg, InjectedToolArg) - # or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)) - # for arg in get_args(type_)[1:] - # ) def get_all_basemodel_annotations( diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index 174e7b2f53704..d3e6e33aaecd4 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -173,12 +173,14 @@ def add(a: int, b: int) -> int: name = name or source_function.__name__ if args_schema is None and infer_schema: # schema name is appended within function + # Use include_injected to decide whether to include injected args in the schema args_schema = create_schema_from_function( name, source_function, parse_docstring=parse_docstring, error_on_invalid_docstring=error_on_invalid_docstring, filter_args=_filter_schema_args(source_function), + # include_injected=False, ) description_ = description if description is None and not parse_docstring: From b8484da5fcb6fd175dca295a8963781bd63614b1 Mon Sep 17 00:00:00 2001 From: XiaoConan Date: Fri, 29 Nov 2024 22:13:48 -0500 Subject: [PATCH 5/8] Update create_schema_from_function --- libs/core/langchain_core/tools/base.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 07dd814e1d59c..c4fce8b89c765 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -264,29 +264,26 @@ def create_schema_from_function( inferred_model = validated.model # type: ignore + # extract the function parameters + existing_params: list[str] = list(sig.parameters.keys()) + + # Create filtered arguments list if filter_args: filter_args_ = filter_args - existing_params: list[str] = list(sig.parameters.keys()) - for existing_param in existing_params: - # add arguments with InjectedToolArg annotation to filter_args_ - if not include_injected and _is_injected_arg_type( - sig.parameters[existing_param].annotation - ): - filter_args_.append(existing_param) else: # Handle classmethods and instance methods - existing_params: list[str] = list(sig.parameters.keys()) if existing_params and existing_params[0] in ("self", "cls") and in_class: filter_args_ = [existing_params[0]] + list(FILTERED_ARGS) else: filter_args_ = list(FILTERED_ARGS) - - for existing_param in existing_params: - if not include_injected and _is_injected_arg_type( - sig.parameters[existing_param].annotation - ): - filter_args_.append(existing_param) + + # add arguments with InjectedToolArg annotation to filter_args_ + for existing_param in existing_params: + if not include_injected and _is_injected_arg_type( + sig.parameters[existing_param].annotation + ): + filter_args_.append(existing_param) description, arg_descriptions = _infer_arg_descriptions( func, From 84cce3f0fcb49adc97b50a1f9d92fa7b15fb8ab1 Mon Sep 17 00:00:00 2001 From: XiaoConan Date: Fri, 29 Nov 2024 23:43:57 -0500 Subject: [PATCH 6/8] Fixed formatting issue --- libs/core/langchain_core/tools/__init__.py | 2 ++ libs/core/langchain_core/tools/base.py | 14 +++++++++----- libs/core/langchain_core/tools/structured.py | 2 +- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/libs/core/langchain_core/tools/__init__.py b/libs/core/langchain_core/tools/__init__.py index 822d5376a22bf..d55afb1255cdf 100644 --- a/libs/core/langchain_core/tools/__init__.py +++ b/libs/core/langchain_core/tools/__init__.py @@ -30,6 +30,8 @@ ) from langchain_core.tools.base import ( InjectedToolArg as InjectedToolArg, +) +from langchain_core.tools.base import ( InjectedToolArgSchema as InjectedToolArgSchema, ) from langchain_core.tools.base import SchemaAnnotationError as SchemaAnnotationError diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index c4fce8b89c765..65619c3bc2309 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -31,9 +31,9 @@ PydanticDeprecationWarning, SkipValidation, ValidationError, + WithJsonSchema, model_validator, validate_arguments, - WithJsonSchema, ) from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 @@ -269,15 +269,15 @@ def create_schema_from_function( # Create filtered arguments list if filter_args: - filter_args_ = filter_args - + filter_args_ = list(filter_args) + else: # Handle classmethods and instance methods if existing_params and existing_params[0] in ("self", "cls") and in_class: filter_args_ = [existing_params[0]] + list(FILTERED_ARGS) else: filter_args_ = list(FILTERED_ARGS) - + # add arguments with InjectedToolArg annotation to filter_args_ for existing_param in existing_params: if not include_injected and _is_injected_arg_type( @@ -959,6 +959,7 @@ def _get_runnable_config_param(func: Callable) -> Optional[str]: class InjectedToolArg: """Annotation for a Tool arg that is **not** meant to be generated by a model.""" + # Add Custom json schema for injected tool arguments InjectedToolArgSchema = Annotated[ InjectedToolArg, @@ -969,11 +970,14 @@ class InjectedToolArg: ), ] + def _is_injected_arg_type(type_: type) -> bool: if type_ is InjectedToolArg: return True for arg in get_args(type_): - if arg is InjectedToolArg or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)): + if arg is InjectedToolArg or ( + isinstance(arg, type) and issubclass(arg, InjectedToolArg) + ): return True return False diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index d3e6e33aaecd4..fd6ee9c416949 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -173,7 +173,7 @@ def add(a: int, b: int) -> int: name = name or source_function.__name__ if args_schema is None and infer_schema: # schema name is appended within function - # Use include_injected to decide whether to include injected args in the schema + # Use include_injected to exclude injected args in the schema args_schema = create_schema_from_function( name, source_function, From 698e5827b538e66ecdeaa296d2efcfb2429504bc Mon Sep 17 00:00:00 2001 From: XiaoConan Date: Sat, 30 Nov 2024 03:06:50 -0500 Subject: [PATCH 7/8] pass InjectedToolArg unit tests --- libs/core/langchain_core/tools/base.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 65619c3bc2309..0a5422c8b07e7 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -104,7 +104,7 @@ def _get_filtered_args( for i, (k, param) in enumerate(valid_keys.items()) if k not in filter_args and (i > 0 or param.name not in ("self", "cls")) - and (include_injected or not _is_injected_arg_type(param.annotation)) + and (include_injected or not _check_injected_arg_type(param.annotation)) } @@ -280,7 +280,7 @@ def create_schema_from_function( # add arguments with InjectedToolArg annotation to filter_args_ for existing_param in existing_params: - if not include_injected and _is_injected_arg_type( + if not include_injected and _check_injected_arg_type( sig.parameters[existing_param].annotation ): filter_args_.append(existing_param) @@ -972,6 +972,16 @@ class InjectedToolArg: def _is_injected_arg_type(type_: type) -> bool: + return any( + isinstance(arg, InjectedToolArg) + or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)) + for arg in get_args(type_)[1:] + ) + + +# Identify if a type contains an InjectedToolArg annotation +# Used to filter out injected arguments from the schema +def _check_injected_arg_type(type_: type) -> bool: if type_ is InjectedToolArg: return True for arg in get_args(type_): From 679d994ef2eff3c5f0036c7b26d6b57591b1a376 Mon Sep 17 00:00:00 2001 From: XiaoConan Date: Mon, 16 Dec 2024 19:59:11 -0500 Subject: [PATCH 8/8] Add unit test for the issue use case --- libs/core/tests/unit_tests/test_tools.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index ce7ea4894bb5a..7d2185e93a42e 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -46,6 +46,7 @@ ) from langchain_core.tools.base import ( InjectedToolArg, + InjectedToolArgSchema, SchemaAnnotationError, _is_message_content_block, _is_message_content_type, @@ -2110,3 +2111,18 @@ def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: return foo.value assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore + + +# Test user specific tools produce correct the correct json schema +# for it's arguments using .args +def test_user_specific_tool_without_schema() -> None: + @tool + def user_specific_tool(x: str, y: InjectedToolArgSchema) -> str: + """Tool that has an injected tool arg.""" + return "User {x} processed {y}" + + # Verify the tool's args schema + assert user_specific_tool.args == { + "x": {"title": "X", "type": "string"}, + "y": {"title": "Y", "type": "Injected-Tool-Argument"}, + }