diff --git a/libs/core/langchain_core/tools/__init__.py b/libs/core/langchain_core/tools/__init__.py index 3553ff10ec79f..d55afb1255cdf 100644 --- a/libs/core/langchain_core/tools/__init__.py +++ b/libs/core/langchain_core/tools/__init__.py @@ -31,6 +31,9 @@ from langchain_core.tools.base import ( InjectedToolArg as InjectedToolArg, ) +from langchain_core.tools.base import ( + InjectedToolArgSchema as InjectedToolArgSchema, +) from langchain_core.tools.base import SchemaAnnotationError as SchemaAnnotationError from langchain_core.tools.base import ( ToolException as ToolException, diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 9782234dfb1a0..0a5422c8b07e7 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -31,6 +31,7 @@ PydanticDeprecationWarning, SkipValidation, ValidationError, + WithJsonSchema, model_validator, validate_arguments, ) @@ -103,7 +104,7 @@ def _get_filtered_args( for i, (k, param) in enumerate(valid_keys.items()) if k not in filter_args and (i > 0 or param.name not in ("self", "cls")) - and (include_injected or not _is_injected_arg_type(param.annotation)) + and (include_injected or not _check_injected_arg_type(param.annotation)) } @@ -263,21 +264,26 @@ def create_schema_from_function( inferred_model = validated.model # type: ignore + # extract the function parameters + existing_params: list[str] = list(sig.parameters.keys()) + + # Create filtered arguments list if filter_args: - filter_args_ = filter_args + filter_args_ = list(filter_args) + else: # Handle classmethods and instance methods - existing_params: list[str] = list(sig.parameters.keys()) if existing_params and existing_params[0] in ("self", "cls") and in_class: filter_args_ = [existing_params[0]] + list(FILTERED_ARGS) else: filter_args_ = list(FILTERED_ARGS) - for existing_param in existing_params: - if not include_injected and _is_injected_arg_type( - sig.parameters[existing_param].annotation - ): - filter_args_.append(existing_param) + # add arguments with InjectedToolArg annotation to filter_args_ + for existing_param in existing_params: + if not include_injected and _check_injected_arg_type( + sig.parameters[existing_param].annotation + ): + filter_args_.append(existing_param) description, arg_descriptions = _infer_arg_descriptions( func, @@ -954,6 +960,17 @@ class InjectedToolArg: """Annotation for a Tool arg that is **not** meant to be generated by a model.""" +# Add Custom json schema for injected tool arguments +InjectedToolArgSchema = Annotated[ + InjectedToolArg, + WithJsonSchema( + { + "type": "Injected-Tool-Argument", + } + ), +] + + def _is_injected_arg_type(type_: type) -> bool: return any( isinstance(arg, InjectedToolArg) @@ -962,6 +979,19 @@ def _is_injected_arg_type(type_: type) -> bool: ) +# Identify if a type contains an InjectedToolArg annotation +# Used to filter out injected arguments from the schema +def _check_injected_arg_type(type_: type) -> bool: + if type_ is InjectedToolArg: + return True + for arg in get_args(type_): + if arg is InjectedToolArg or ( + isinstance(arg, type) and issubclass(arg, InjectedToolArg) + ): + return True + return False + + def get_all_basemodel_annotations( cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True ) -> dict[str, type]: diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index 174e7b2f53704..fd6ee9c416949 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -173,12 +173,14 @@ def add(a: int, b: int) -> int: name = name or source_function.__name__ if args_schema is None and infer_schema: # schema name is appended within function + # Use include_injected to exclude injected args in the schema args_schema = create_schema_from_function( name, source_function, parse_docstring=parse_docstring, error_on_invalid_docstring=error_on_invalid_docstring, filter_args=_filter_schema_args(source_function), + # include_injected=False, ) description_ = description if description is None and not parse_docstring: diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index ce7ea4894bb5a..7d2185e93a42e 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -46,6 +46,7 @@ ) from langchain_core.tools.base import ( InjectedToolArg, + InjectedToolArgSchema, SchemaAnnotationError, _is_message_content_block, _is_message_content_type, @@ -2110,3 +2111,18 @@ def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: return foo.value assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore + + +# Test user specific tools produce correct the correct json schema +# for it's arguments using .args +def test_user_specific_tool_without_schema() -> None: + @tool + def user_specific_tool(x: str, y: InjectedToolArgSchema) -> str: + """Tool that has an injected tool arg.""" + return "User {x} processed {y}" + + # Verify the tool's args schema + assert user_specific_tool.args == { + "x": {"title": "X", "type": "string"}, + "y": {"title": "Y", "type": "Injected-Tool-Argument"}, + }