Skip to content

Commit

Permalink
anthropic[patch]: handle lists in function calling (#18609)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis authored Mar 5, 2024
1 parent 1831733 commit e169ee8
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 16 deletions.
59 changes: 46 additions & 13 deletions libs/partners/anthropic/langchain_anthropic/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@
</parameter>"""


def _get_type(parameter: Dict[str, Any]) -> str:
if "type" in parameter:
return parameter["type"]
if "anyOf" in parameter:
return json.dumps({"anyOf": parameter["anyOf"]})
if "allOf" in parameter:
return json.dumps({"allOf": parameter["allOf"]})
return json.dumps(parameter)


def get_system_message(tools: List[Dict]) -> str:
tools_data: List[Dict] = [
{
Expand All @@ -78,7 +88,7 @@ def get_system_message(tools: List[Dict]) -> str:
[
TOOL_PARAMETER_FORMAT.format(
parameter_name=name,
parameter_type=parameter["type"],
parameter_type=_get_type(parameter),
parameter_description=parameter.get("description"),
)
for name, parameter in tool["parameters"]["properties"].items()
Expand Down Expand Up @@ -118,21 +128,44 @@ def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
return d


def _xml_to_tool_calls(elem: Any) -> List[Dict[str, Any]]:
def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]:
name = invoke.find("tool_name").text
arguments = _xml_to_dict(invoke.find("parameters"))

# make list elements in arguments actually lists
filtered_tools = [tool for tool in tools if tool["name"] == name]
if len(filtered_tools) > 0 and not isinstance(arguments, str):
tool = filtered_tools[0]
for key, value in arguments.items():
if key in tool["parameters"]["properties"]:
if "type" in tool["parameters"]["properties"][key]:
if tool["parameters"]["properties"][key][
"type"
] == "array" and not isinstance(value, list):
arguments[key] = [value]
if (
tool["parameters"]["properties"][key]["type"] != "object"
and isinstance(value, dict)
and len(value.keys()) == 1
):
arguments[key] = list(value.values())[0]

return {
"function": {
"name": name,
"arguments": json.dumps(arguments),
},
"type": "function",
}


def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
"""
Convert an XML element and its children into a dictionary of dictionaries.
"""
invokes = elem.findall("invoke")
return [
{
"function": {
"name": invoke.find("tool_name").text,
"arguments": json.dumps(_xml_to_dict(invoke.find("parameters"))),
},
"type": "function",
}
for invoke in invokes
]

return [_xml_to_function_call(invoke, tools) for invoke in invokes]


@beta()
Expand Down Expand Up @@ -262,7 +295,7 @@ def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
xml_text = text[start:end]

xml = self._xmllib.fromstring(xml_text)
additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml)
additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml, tools)
text = ""
except Exception:
pass
Expand Down
4 changes: 2 additions & 2 deletions libs/partners/anthropic/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-anthropic"
version = "0.1.2"
version = "0.1.3"
description = "An integration package connecting AnthropicMessages and LangChain"
authors = []
readme = "README.md"
Expand All @@ -14,7 +14,7 @@ license = "MIT"
python = ">=3.8.1,<4.0"
langchain-core = "^0.1"
anthropic = ">=0.17.0,<1"
defusedxml = {version = "^0.7.1", optional = true}
defusedxml = { version = "^0.7.1", optional = true }

[tool.poetry.group.test]
optional = true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Test ChatAnthropic chat model."""

import json
from enum import Enum
from typing import List, Optional

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_anthropic.experimental import ChatAnthropicTools

Expand Down Expand Up @@ -129,3 +131,49 @@ class Person(BaseModel):
assert isinstance(result, Person)
assert result.name == "Erick"
assert result.age == 27


def test_anthropic_complex_structured_output() -> None:
class ToneEnum(str, Enum):
positive = "positive"
negative = "negative"

class Email(BaseModel):
"""Relevant information about an email."""

sender: Optional[str] = Field(
None, description="The sender's name, if available"
)
sender_phone_number: Optional[str] = Field(
None, description="The sender's phone number, if available"
)
sender_address: Optional[str] = Field(
None, description="The sender's address, if available"
)
action_items: List[str] = Field(
..., description="A list of action items requested by the email"
)
topic: str = Field(
..., description="High level description of what the email is about"
)
tone: ToneEnum = Field(..., description="The tone of the email.")

prompt = ChatPromptTemplate.from_messages(
[
(
"human",
"What can you tell me about the following email? Make sure to answer in the correct format: {email}", # noqa: E501
),
]
)

llm = ChatAnthropicTools(temperature=0, model_name="claude-3-sonnet-20240229")

extraction_chain = prompt | llm.with_structured_output(Email)

response = extraction_chain.invoke(
{
"email": "From: Erick. The email is about the new project. The tone is positive. The action items are to send the report and to schedule a meeting." # noqa: E501
}
) # noqa: E501
assert isinstance(response, Email)

0 comments on commit e169ee8

Please sign in to comment.