From 4e743b54275d6e9dabeaaf92b3ee5d29cbd7c181 Mon Sep 17 00:00:00 2001 From: Filip Ratajczak <90644163+Tesla2000@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:27:25 -0800 Subject: [PATCH] Core: google docstring parsing fix (#28404) Thank you for contributing to LangChain! - [ ] **PR title**: "core: google docstring parsing fix" - [x] **PR message**: - **Description:** Added a solution for invalid parsing of google docstring such as: Args: net_annual_income (float): The user's net annual income (in current year dollars). - **Issue:** Previous code would return arg = "net_annual_income (float)" which would cause exception in _validate_docstring_args_against_annotations - **Dependencies:** None If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. Co-authored-by: Erick Friis --- .../core/langchain_core/utils/function_calling.py | 8 ++++++-- .../unit_tests/utils/test_function_calling.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 4779d26244203..e6b70c4ade1d7 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -646,9 +646,13 @@ def _parse_google_docstring( for line in args_block.split("\n")[1:]: if ":" in line: arg, desc = line.split(":", maxsplit=1) - arg_descriptions[arg.strip()] = desc.strip() + arg = arg.strip() + arg_name, _, _annotations = arg.partition(" ") + if _annotations.startswith("(") and _annotations.endswith(")"): + arg = arg_name + arg_descriptions[arg] = desc.strip() elif arg: - arg_descriptions[arg.strip()] += " " + line.strip() + arg_descriptions[arg] += " " + line.strip() return description, arg_descriptions diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index ba4c50187f139..bf1a4f56337fe 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -71,6 +71,19 @@ def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None: return dummy_function +@pytest.fixture() +def function_docstring_annotations() -> Callable: + def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None: + """dummy function + + Args: + arg1 (int): foo + arg2: one of 'bar', 'baz' + """ + + return dummy_function + + @pytest.fixture() def runnable() -> Runnable: class Args(ExtensionsTypedDict): @@ -278,6 +291,7 @@ def dummy_function(cls, arg1: int, arg2: Literal["bar", "baz"]) -> None: def test_convert_to_openai_function( pydantic: type[BaseModel], function: Callable, + function_docstring_annotations: Callable, dummy_structured_tool: StructuredTool, dummy_tool: BaseTool, json_schema: dict, @@ -311,6 +325,7 @@ def test_convert_to_openai_function( for fn in ( pydantic, function, + function_docstring_annotations, dummy_structured_tool, dummy_tool, json_schema,