Feat: claude 3 tool calling #70

Merged 19 commits on Jun 12, 2024
Changes from all commits
336 changes: 295 additions & 41 deletions libs/aws/langchain_aws/chat_models/bedrock.py

Large diffs are not rendered by default.

83 changes: 83 additions & 0 deletions libs/aws/langchain_aws/function_calling.py
@@ -8,10 +8,16 @@
Dict,
List,
Literal,
Optional,
Type,
Union,
cast,
)

from langchain_core.messages import ToolCall
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.prompts.chat import AIMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -63,6 +69,35 @@ class AnthropicTool(TypedDict):
input_schema: Dict[str, Any]


def _tools_in_params(params: dict) -> bool:
return "tools" in params or (
"extra_body" in params and params["extra_body"].get("tools")
)
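
A quick sketch of how this check behaves, with made-up parameter dicts (illustrative values only):

_tools_in_params({"tools": [{"name": "get_weather"}]})                  # truthy: tools at the top level
_tools_in_params({"extra_body": {"tools": [{"name": "get_weather"}]}})  # truthy: tools nested under extra_body
_tools_in_params({"temperature": 0.0})                                  # falsy: no tools supplied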


class _AnthropicToolUse(TypedDict):
type: Literal["tool_use"]
name: str
input: dict
id: str


def _lc_tool_calls_to_anthropic_tool_use_blocks(
tool_calls: List[ToolCall],
) -> List[_AnthropicToolUse]:
blocks = []
for tool_call in tool_calls:
blocks.append(
_AnthropicToolUse(
type="tool_use",
name=tool_call["name"],
input=tool_call["args"],
id=cast(str, tool_call["id"]),
)
)
return blocks
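
As a rough illustration of the mapping this helper performs (example values are assumptions, not taken from the PR):

from langchain_core.messages import ToolCall

call = ToolCall(name="get_weather", args={"city": "Paris"}, id="toolu_01A")
_lc_tool_calls_to_anthropic_tool_use_blocks([call])
# [{"type": "tool_use", "name": "get_weather", "input": {"city": "Paris"}, "id": "toolu_01A"}]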


def _get_type(parameter: Dict[str, Any]) -> str:
if "type" in parameter:
return parameter["type"]
@@ -122,6 +157,54 @@ class ToolDescription(TypedDict):
function: FunctionDescription


class ToolsOutputParser(BaseGenerationOutputParser):
first_tool_only: bool = False
args_only: bool = False
pydantic_schemas: Optional[List[Type[BaseModel]]] = None

class Config:
extra = "forbid"

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse a list of candidate model Generations into a specific format.

Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.

Returns:
Structured output.
"""
if not result or not isinstance(result[0], ChatGeneration):
return None if self.first_tool_only else []
message = result[0].message
if len(message.content) > 0:
tool_calls: List = []
else:
content = cast(AIMessage, message)
_tool_calls = [dict(tc) for tc in content.tool_calls]
# Map tool call id to index
id_to_index = {block["id"]: i for i, block in enumerate(_tool_calls)}
tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
if self.pydantic_schemas:
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
elif self.args_only:
tool_calls = [tc["args"] for tc in tool_calls]
else:
pass

if self.first_tool_only:
return tool_calls[0] if tool_calls else None
else:
return [tool_call for tool_call in tool_calls]

def _pydantic_parse(self, tool_call: dict) -> BaseModel:
cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
tool_call["name"]
]
return cls_(**tool_call["args"])
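
A rough usage sketch for this parser. The ChatBedrock construction and bind_tools call refer to the chat-model changes in this PR that are not rendered above, so treat the chain below as an assumption rather than the PR's documented API:

from langchain_core.pydantic_v1 import BaseModel

class GetWeather(BaseModel):
    """Get the current weather for a city."""

    city: str

# Hypothetical chain: bind the tool, then parse tool calls back into the schema.
# llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
# chain = llm.bind_tools([GetWeather]) | ToolsOutputParser(
#     pydantic_schemas=[GetWeather], first_tool_only=True
# )
# chain.invoke("What is the weather in Paris?")  # -> GetWeather(city="Paris")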


def convert_to_anthropic_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> AnthropicTool:
117 changes: 96 additions & 21 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -12,6 +12,7 @@
Mapping,
Optional,
Tuple,
TypedDict,
Union,
)

@@ -21,10 +22,12 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LLM, BaseLanguageModel
from langchain_core.messages import ToolCall
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_aws.function_calling import _tools_in_params
from langchain_aws.utils import (
enforce_stop_tokens,
get_num_tokens_anthropic,
@@ -81,7 +84,10 @@ def _human_assistant_format(input_text: str) -> str:


def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any],
provider: str,
output_key: str,
messages_api: bool,
) -> Union[GenerationChunk, None]:
"""Convert a stream response to a generation chunk."""
if messages_api:
@@ -174,6 +180,23 @@ def _combine_generation_info_for_llm_result(
return {"usage": total_usage_info, "stop_reason": stop_reason}


def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
tool_calls = []
for block in content:
if block["type"] != "tool_use":
continue
tool_calls.append(
ToolCall(name=block["name"], args=block["input"], id=block["id"])
)
return tool_calls
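
For illustration, a sketch of the content blocks this expects, mirroring the Anthropic Messages API response shape (values are made up):

content = [
    {"type": "text", "text": "Let me look that up."},
    {"type": "tool_use", "name": "get_weather", "input": {"city": "Paris"}, "id": "toolu_01A"},
]
extract_tool_calls(content)
# [{"name": "get_weather", "args": {"city": "Paris"}, "id": "toolu_01A"}]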


class AnthropicTool(TypedDict):
name: str
description: str
input_schema: Dict[str, Any]
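
For context, a tool definition in this shape might look as follows (assumed example; the schema layout follows Anthropic's documented tool format). This is what arrives via params["tools"] and, for Claude 3 models, is forwarded as the request body's "tools" field in prepare_input below:

weather_tool: AnthropicTool = {
    "name": "get_weather",
    "description": "Get the current weather for a city.",
    "input_schema": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}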


class LLMInputOutputAdapter:
"""Adapter class to prepare the inputs from Langchain to a format
that the LLM model expects.
@@ -197,10 +220,13 @@ def prepare_input(
prompt: Optional[str] = None,
system: Optional[str] = None,
messages: Optional[List[Dict]] = None,
tools: Optional[List[AnthropicTool]] = None,
) -> Dict[str, Any]:
input_body = {**model_kwargs}
if provider == "anthropic":
if messages:
if tools:
input_body["tools"] = tools
input_body["anthropic_version"] = "bedrock-2023-05-31"
input_body["messages"] = messages
if system:
@@ -225,16 +251,20 @@
@classmethod
def prepare_output(cls, provider: str, response: Any) -> dict:
text = ""
tool_calls = []
response_body = json.loads(response.get("body").read().decode())

if provider == "anthropic":
if "completion" in response_body:
text = response_body.get("completion")
elif "content" in response_body:
content = response_body.get("content")
if len(content) == 1 and content[0]["type"] == "text":
text = content[0]["text"]
elif any(block["type"] == "tool_use" for block in content):
tool_calls = extract_tool_calls(content)

else:
if provider == "ai21":
text = response_body.get("completions")[0].get("data").get("text")
elif provider == "cohere":
@@ -251,6 +281,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
return {
"text": text,
"tool_calls": tool_calls,
"body": response_body,
"usage": {
"prompt_tokens": prompt_tokens,
Expand Down Expand Up @@ -584,19 +615,32 @@ def _prepare_input_and_invoke(
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Tuple[
str,
List[dict],
Dict[str, Any],
]:
_model_kwargs = self.model_kwargs or {}

provider = self._get_provider()
params = {**_model_kwargs, **kwargs}

input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
)
if "claude-3" in self._get_model():
if _tools_in_params(params):
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
tools=params["tools"],
)
body = json.dumps(input_body)
accept = "application/json"
contentType = "application/json"
@@ -621,9 +665,13 @@ def _prepare_input_and_invoke(
try:
response = self.client.invoke_model(**request_options)

(
text,
tool_calls,
body,
usage_info,
stop_reason,
) = LLMInputOutputAdapter.prepare_output(provider, response).values()

except Exception as e:
raise ValueError(f"Error raised by bedrock service: {e}")
@@ -646,7 +694,7 @@ def _prepare_input_and_invoke(
**services_trace,
)

return text, tool_calls, llm_output

def _get_bedrock_services_signal(self, body: dict) -> dict:
"""
@@ -711,6 +759,16 @@ def _prepare_input_and_invoke_stream(
messages=messages,
model_kwargs=params,
)
if "claude-3" in self._get_model():
if _tools_in_params(params):
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
tools=params["tools"],
)
body = json.dumps(input_body)

request_options = {
@@ -737,7 +795,10 @@ def _prepare_input_and_invoke_stream(
raise ValueError(f"Error raised by bedrock service: {e}")

for chunk in LLMInputOutputAdapter.prepare_output_stream(
provider,
response,
stop,
True if messages else False,
):
yield chunk
# verify and raise callback error if any middleware intervened
@@ -770,13 +831,24 @@ async def _aprepare_input_and_invoke_stream(
_model_kwargs["stream"] = True

params = {**_model_kwargs, **kwargs}
if "claude-3" in self._get_model():
if _tools_in_params(params):
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
tools=params["tools"],
)
else:
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
prompt=prompt,
system=system,
messages=messages,
model_kwargs=params,
)
body = json.dumps(input_body)

response = await asyncio.get_running_loop().run_in_executor(
@@ -790,7 +862,10 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
)

async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
provider,
response,
stop,
True if messages else False,
):
yield chunk
if run_manager is not None and asyncio.iscoroutinefunction(
@@ -951,7 +1026,7 @@ def _call(

return completion

text, tool_calls, llm_output = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)
if run_manager is not None: