From c4012c274468677aec31dbfa226e5a51ef5c3fb7 Mon Sep 17 00:00:00 2001 From: hmnfalahi Date: Sat, 20 Apr 2024 12:55:27 +0330 Subject: [PATCH] Modified _get_python_function_required_args to ignore self/cls from required args of class functions --- .../langchain_core/utils/function_calling.py | 8 +++-- .../unit_tests/utils/test_function_calling.py | 34 ++++++++++++++++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 860259a93e35f..5861becb5f321 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -17,6 +17,7 @@ Union, cast, ) +from types import FunctionType, MethodType from typing_extensions import TypedDict @@ -200,8 +201,11 @@ def _get_python_function_required_args(function: Callable) -> List[str]: required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})] - is_class = type(function) is type - if is_class and required[0] == "self": + is_function_type = isinstance(function, FunctionType) + is_method_type = isinstance(function, MethodType) + if is_function_type and required[0] == "self": + required = required[1:] + elif is_method_type and required[0] == "cls": required = required[1:] return required diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 00328bcf29b44..b58ade6746df8 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -71,11 +71,42 @@ def json_schema() -> Dict: } +@pytest.fixture() +def dummy_instance_method() -> object: + class Dummy: + def dummy_function(self, arg1: int, arg2: Literal["bar", "baz"]) -> None: + """dummy function + + Args: + arg1: foo + arg2: one of 'bar', 'baz' + """ + pass + return Dummy + + +@pytest.fixture() +def dummy_class_method() -> object: + class Dummy: + @classmethod + def dummy_function(cls, arg1: int, arg2: Literal["bar", "baz"]) -> None: + """dummy function + + Args: + arg1: foo + arg2: one of 'bar', 'baz' + """ + pass + return Dummy + + def test_convert_to_openai_function( pydantic: Type[BaseModel], function: Callable, dummy_tool: BaseTool, json_schema: Dict, + dummy_instance_method: object, + dummy_class_method: object, ) -> None: expected = { "name": "dummy_function", @@ -94,7 +125,8 @@ def test_convert_to_openai_function( }, } - for fn in (pydantic, function, dummy_tool, json_schema, expected): + for fn in (pydantic, function, dummy_tool, json_schema, expected, + dummy_instance_method.dummy_function, dummy_class_method.dummy_function): actual = convert_to_openai_function(fn) # type: ignore assert actual == expected