Skip to content

Commit

Permalink
Modified _get_python_function_required_args to ignore self/cls from r…
Browse files Browse the repository at this point in the history
…equired args of class functions
  • Loading branch information
hmnfalahi committed Apr 24, 2024
1 parent ed98060 commit c4012c2
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
8 changes: 6 additions & 2 deletions libs/core/langchain_core/utils/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Union,
cast,
)
from types import FunctionType, MethodType

from typing_extensions import TypedDict

Expand Down Expand Up @@ -200,8 +201,11 @@ def _get_python_function_required_args(function: Callable) -> List[str]:
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]

is_class = type(function) is type
if is_class and required[0] == "self":
is_function_type = isinstance(function, FunctionType)
is_method_type = isinstance(function, MethodType)
if is_function_type and required[0] == "self":
required = required[1:]
elif is_method_type and required[0] == "cls":
required = required[1:]
return required

Expand Down
34 changes: 33 additions & 1 deletion libs/core/tests/unit_tests/utils/test_function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,42 @@ def json_schema() -> Dict:
}


@pytest.fixture()
def dummy_instance_method() -> object:
class Dummy:
def dummy_function(self, arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
Args:
arg1: foo
arg2: one of 'bar', 'baz'
"""
pass
return Dummy


@pytest.fixture()
def dummy_class_method() -> object:
class Dummy:
@classmethod
def dummy_function(cls, arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
Args:
arg1: foo
arg2: one of 'bar', 'baz'
"""
pass
return Dummy


def test_convert_to_openai_function(
pydantic: Type[BaseModel],
function: Callable,
dummy_tool: BaseTool,
json_schema: Dict,
dummy_instance_method: object,
dummy_class_method: object,
) -> None:
expected = {
"name": "dummy_function",
Expand All @@ -94,7 +125,8 @@ def test_convert_to_openai_function(
},
}

for fn in (pydantic, function, dummy_tool, json_schema, expected):
for fn in (pydantic, function, dummy_tool, json_schema, expected,
dummy_instance_method.dummy_function, dummy_class_method.dummy_function):
actual = convert_to_openai_function(fn) # type: ignore
assert actual == expected

Expand Down

0 comments on commit c4012c2

Please sign in to comment.