Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: claude 3 tool calling #70

Merged
merged 19 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 153 additions & 30 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import re
from collections import defaultdict
from typing import (
Expand Down Expand Up @@ -29,12 +30,18 @@
HumanMessage,
SystemMessage,
)
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool

from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message
from langchain_aws.function_calling import (
_lc_tool_calls_to_anthropic_tool_use_blocks,
_tools_in_params,
convert_to_anthropic_tool,
get_system_message,
)
from langchain_aws.llms.bedrock import (
BedrockBase,
_combine_generation_info_for_llm_result,
Expand Down Expand Up @@ -197,23 +204,61 @@ def _format_image(image_url: str) -> Dict:
}


def _merge_messages(
    messages: Sequence[BaseMessage],
) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
    """Merge runs of human/tool messages into single human messages with content blocks.

    Anthropic requires tool results to arrive as ``tool_result`` content
    blocks inside a human turn, and does not accept consecutive messages
    with the same role — so tool messages are converted and adjacent human
    messages are coalesced.
    """  # noqa: E501
    result: list = []
    for original in messages:
        msg = original.copy(deep=True)
        if isinstance(msg, ToolMessage):
            # String tool output becomes a tool_result block tied back to
            # the originating tool_use via tool_call_id; non-string content
            # is passed through as-is inside a human message.
            if isinstance(msg.content, str):
                tool_block = {
                    "type": "tool_result",
                    "content": msg.content,
                    "tool_use_id": msg.tool_call_id,
                }
                msg = HumanMessage([tool_block])  # type: ignore[misc]
            else:
                msg = HumanMessage(msg.content)  # type: ignore[misc]
        previous = result[-1] if result else None
        if isinstance(previous, HumanMessage) and isinstance(msg, HumanMessage):
            # Coalesce into the previous human message, normalizing any
            # plain-string content to a list of text blocks first.
            if isinstance(previous.content, str):
                combined: List = [{"type": "text", "text": previous.content}]
            else:
                combined = previous.content
            if isinstance(msg.content, str):
                combined.append({"type": "text", "text": msg.content})
            else:
                combined.extend(msg.content)
            previous.content = combined
        else:
            result.append(msg)
    return result


def _format_anthropic_messages(
messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
"""Format messages for anthropic."""

"""
[
{
"role": _message_type_lookups[m.type],
"content": [_AnthropicMessageContent(text=m.content).dict()],
}
for m in messages
]
{
"role": _message_type_lookups[m.type],
"content": [_AnthropicMessageContent(text=m.content).dict()],
}
for m in messages
]
"""
system: Optional[str] = None
formatted_messages: List[Dict] = []
for i, message in enumerate(messages):

merged_messages = _merge_messages(messages)
for i, message in enumerate(merged_messages):
if message.type == "system":
if i != 0:
raise ValueError("System message must be at beginning of message list.")
Expand All @@ -226,7 +271,7 @@ def _format_anthropic_messages(
continue

role = _message_type_lookups[message.type]
content: Union[str, List[Dict]]
content: Union[str, List]

if not isinstance(message.content, str):
# parse as dict
Expand All @@ -238,39 +283,58 @@ def _format_anthropic_messages(
content = []
for item in message.content:
if isinstance(item, str):
content.append(
{
"type": "text",
"text": item,
}
)
content.append({"type": "text", "text": item})
elif isinstance(item, dict):
if "type" not in item:
raise ValueError("Dict content item must have a type key")
if item["type"] == "image_url":
elif item["type"] == "image_url":
# convert format
source = _format_image(item["image_url"]["url"])
content.append(
{
"type": "image",
"source": source,
}
)
content.append({"type": "image", "source": source})
elif item["type"] == "tool_use":
# If a tool_call with the same id as a tool_use content block
# exists, the tool_call is preferred.
if isinstance(message, AIMessage) and item["id"] in [
tc["id"] for tc in message.tool_calls
]:
overlapping = [
tc
for tc in message.tool_calls
if tc["id"] == item["id"]
]
content.extend(
_lc_tool_calls_to_anthropic_tool_use_blocks(overlapping)
)
else:
item.pop("text", None)
content.append(item)
elif item["type"] == "text":
text = item.get("text", "")
# Only add non-empty strings for now as empty ones are not
# accepted.
# https://github.com/anthropics/anthropic-sdk-python/issues/461
if text.strip():
content.append({"type": "text", "text": text})
else:
content.append(item)
else:
raise ValueError(
f"Content items must be str or dict, instead was: {type(item)}"
)
elif isinstance(message, AIMessage) and message.tool_calls:
content = (
[]
if not message.content
else [{"type": "text", "text": message.content}]
)
# Note: Anthropic can't have invalid tool calls as presently defined,
# since the model already returns dicts args not JSON strings, and invalid
# tool calls are those with invalid JSON for args.
content += _lc_tool_calls_to_anthropic_tool_use_blocks(message.tool_calls)
else:
content = message.content

formatted_messages.append(
{
"role": role,
"content": content,
}
)
formatted_messages.append({"role": role, "content": content})
return system, formatted_messages


Expand Down Expand Up @@ -363,10 +427,38 @@ def _stream(
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None

if "claude-3" in self._get_model():
if _tools_in_params({**kwargs}):
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
message = result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
{
"name": tool_call["name"],
"args": json.dumps(tool_call["args"]),
"id": tool_call["id"],
"index": idx,
}
for idx, tool_call in enumerate(message.tool_calls)
]
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=message.usage_metadata,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
if provider == "anthropic":
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
# use tools the new way with claude 3
# if "claude-3" in self._get_model():
# if _tools_in_params()
baskaryan marked this conversation as resolved.
Show resolved Hide resolved
if self.system_prompt_with_tools:
if system:
system = self.system_prompt_with_tools + f"\n{system}"
Expand Down Expand Up @@ -403,6 +495,7 @@ def _generate(
) -> ChatResult:
completion = ""
llm_output: Dict[str, Any] = {}
tool_calls: List[Dict[str, Any]] = []
provider_stop_reason_code = self.provider_stop_reason_key_map.get(
self._get_provider(), "stop_reason"
)
Expand All @@ -411,6 +504,8 @@ def _generate(
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
response_metadata.append(chunk.message.response_metadata)
if "tool_calls" in chunk.message.additional_kwargs.keys():
tool_calls = chunk.message.additional_kwargs["tool_calls"]
llm_output = _combine_generation_info_for_llm_result(
response_metadata, provider_stop_reason_code
)
Expand All @@ -423,6 +518,7 @@ def _generate(
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
# use tools the new way with claude 3
if self.system_prompt_with_tools:
if system:
system = self.system_prompt_with_tools + f"\n{system}"
Expand All @@ -436,7 +532,7 @@ def _generate(
if stop:
params["stop_sequences"] = stop

completion, llm_output = self._prepare_input_and_invoke(
completion, tool_calls, llm_output = self._prepare_input_and_invoke(
prompt=prompt,
stop=stop,
run_manager=run_manager,
Expand All @@ -446,10 +542,18 @@ def _generate(
)

llm_output["model_id"] = self.model_id
if len(tool_calls) > 0:
msg = AIMessage(
content=completion,
additional_kwargs=llm_output,
tool_calls=cast(List[ToolCall], tool_calls),
)
else:
msg = AIMessage(content=completion, additional_kwargs=llm_output)
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(content=completion, additional_kwargs=llm_output)
message=msg,
)
],
llm_output=llm_output,
Expand Down Expand Up @@ -511,6 +615,25 @@ def bind_tools(

if provider == "anthropic":
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]

# true if the model is a claude 3 model
if "claude-3" in self._get_model():
if not tool_choice:
pass
elif isinstance(tool_choice, dict):
kwargs["tool_choice"] = tool_choice
elif isinstance(tool_choice, str) and tool_choice in ("any", "auto"):
kwargs["tool_choice"] = {"type": tool_choice}
elif isinstance(tool_choice, str):
kwargs["tool_choice"] = {"type": "tool", "name": tool_choice}
else:
raise ValueError(
f"Unrecognized 'tool_choice' type {tool_choice=}."
f"Expected dict, str, or None."
)
return self.bind(tools=formatted_tools, **kwargs)

# add tools to the system prompt, the old way
system_formatted_tools = get_system_message(formatted_tools)
self.set_system_prompt_with_tools(system_formatted_tools)
return self
Expand Down
31 changes: 31 additions & 0 deletions libs/aws/langchain_aws/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
Literal,
Type,
Union,
cast,
)

from langchain_core.messages.tool import ToolCall
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
Expand Down Expand Up @@ -63,6 +65,35 @@ class AnthropicTool(TypedDict):
input_schema: Dict[str, Any]


def _tools_in_params(params: dict) -> bool:
return "tools" in params or (
"extra_body" in params and params["extra_body"].get("tools")
)


class _AnthropicToolUse(TypedDict):
    """A single Anthropic ``tool_use`` content block."""

    type: Literal["tool_use"]
    name: str
    input: dict
    id: str


def _lc_tool_calls_to_anthropic_tool_use_blocks(
    tool_calls: List[ToolCall],
) -> List[_AnthropicToolUse]:
    """Convert LangChain tool calls into Anthropic ``tool_use`` content blocks.

    Each tool call's name, args, and id map directly onto the corresponding
    fields of a ``tool_use`` block; order is preserved.
    """
    return [
        _AnthropicToolUse(
            type="tool_use",
            name=call["name"],
            input=call["args"],
            # ToolCall allows id to be Optional; narrow for the block type.
            id=cast(str, call["id"]),
        )
        for call in tool_calls
    ]


def _get_type(parameter: Dict[str, Any]) -> str:
if "type" in parameter:
return parameter["type"]
Expand Down
Loading
Loading