Skip to content

Commit

Permalink
feat: support legacy tool use for Claude 3 models (#98)
Browse files — browse the repository at this point in the history
  • Loading branch information
adubovik authored May 2, 2024
1 parent 91ccf82 commit 1345bb8
Show file tree
Hide file tree
Showing 17 changed files with 471 additions and 423 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ The following models support `POST SERVER_URL/openai/deployments/MODEL_NAME/chat
|anthropic.claude-v1|text-to-text|||||
|anthropic.claude-v2|text-to-text|||||
|anthropic.claude-v2:1|text-to-text|||||
|anthropic.claude-3-sonnet-20240229-v1:0|text-to-text, image-to-text|||||
|anthropic.claude-3-haiku-20240307-v1:0|text-to-text, image-to-text|||||
|anthropic.claude-3-opus-20240229-v1:0|text-to-text, image-to-text|||||
|anthropic.claude-3-sonnet-20240229-v1:0|text-to-text, image-to-text|||||
|anthropic.claude-3-haiku-20240307-v1:0|text-to-text, image-to-text|||||
|anthropic.claude-3-opus-20240229-v1:0|text-to-text, image-to-text|||||
|stability.stable-diffusion-xl|text-to-image|||||
|meta.llama2-13b-chat-v1|text-to-text|||||
|meta.llama2-70b-chat-v1|text-to-text|||||
Expand Down
44 changes: 16 additions & 28 deletions aidial_adapter_bedrock/llm/chat_model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
from abc import ABC, abstractmethod
from typing import AsyncIterator, Callable, List, Optional

from aidial_sdk.chat_completion import Message
from aidial_sdk.chat_completion import Message, Role
from pydantic import BaseModel

import aidial_adapter_bedrock.utils.stream as stream_utils
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.llm.chat_emulator import ChatEmulator
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.exceptions import ValidationError
from aidial_adapter_bedrock.llm.message import (
BaseMessage,
SystemMessage,
parse_message,
)
from aidial_adapter_bedrock.llm.message import BaseMessage, SystemMessage
from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator
from aidial_adapter_bedrock.llm.tools.tool_config import ToolConfig
from aidial_adapter_bedrock.llm.truncate_prompt import (
Expand All @@ -25,8 +21,12 @@
from aidial_adapter_bedrock.utils.not_implemented import not_implemented


def _is_empty_system_message(msg: BaseMessage) -> bool:
return isinstance(msg, SystemMessage) and msg.content.strip() == ""
def _is_empty_system_message(msg: Message) -> bool:
return (
msg.role == Role.SYSTEM
and msg.content is not None
and msg.content.strip() == ""
)


class ChatCompletionAdapter(ABC, BaseModel):
Expand Down Expand Up @@ -78,10 +78,7 @@ def truncate_and_linearize_messages(
) -> TextCompletionPrompt:
pass

def validate_base_messages(
self,
messages: List[BaseMessage],
) -> List[BaseMessage]:
def preprocess_messages(self, messages: List[Message]) -> List[Message]:
# Skipping empty system messages
messages = [
msg for msg in messages if not _is_empty_system_message(msg)
Expand All @@ -92,25 +89,14 @@ def validate_base_messages(

return messages

def get_base_messages(
self, params: ModelParameters, messages: List[Message]
) -> List[BaseMessage]:
parsed_messages = list(map(parse_message, messages))
base_messages = self.tools_emulator(
params.tool_config
).convert_to_base_messages(parsed_messages)
return self.validate_base_messages(base_messages)

def get_text_completion_prompt(
self, params: ModelParameters, messages: List[Message]
) -> TextCompletionPrompt:
tools_emulator = self.tools_emulator(params.tool_config)

base_messages = self.get_base_messages(params, messages)

(base_messages, tool_stop_sequences) = (
tools_emulator.add_tool_declarations(base_messages)
)
messages = self.preprocess_messages(messages)
tools_emulator = self.tools_emulator(params.tool_config)
base_messages = tools_emulator.parse_dial_messages(messages)
tool_stop_sequences = tools_emulator.get_stop_sequences()

prompt = self.truncate_and_linearize_messages(
base_messages, params.max_prompt_tokens
Expand Down Expand Up @@ -166,7 +152,9 @@ class PseudoChatModel(TextCompletionAdapter):
async def count_prompt_tokens(
self, params: ModelParameters, messages: List[Message]
) -> int:
base_messages = self.get_base_messages(params, messages)
messages = self.preprocess_messages(messages)
tools_emulator = self.tools_emulator(params.tool_config)
base_messages = tools_emulator.parse_dial_messages(messages)
return self.tokenize_messages(base_messages)

async def count_completion_tokens(self, string: str) -> int:
Expand Down
29 changes: 23 additions & 6 deletions aidial_adapter_bedrock/llm/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from typing import List, Optional, Union

from aidial_sdk.chat_completion import FunctionCall, Message, Role, ToolCall
from aidial_sdk.chat_completion import (
CustomContent,
FunctionCall,
Message,
Role,
ToolCall,
)
from pydantic import BaseModel

from aidial_adapter_bedrock.llm.exceptions import ValidationError
Expand All @@ -12,6 +18,7 @@ class SystemMessage(BaseModel):

class HumanRegularMessage(BaseModel):
content: str
custom_content: Optional[CustomContent] = None


class HumanToolResultMessage(BaseModel):
Expand All @@ -26,6 +33,7 @@ class HumanFunctionResultMessage(BaseModel):

class AIRegularMessage(BaseModel):
content: str
custom_content: Optional[CustomContent] = None


class AIToolCallMessage(BaseModel):
Expand All @@ -37,6 +45,7 @@ class AIFunctionCallMessage(BaseModel):


BaseMessage = Union[SystemMessage, HumanRegularMessage, AIRegularMessage]

ToolMessage = Union[
HumanToolResultMessage,
HumanFunctionResultMessage,
Expand All @@ -49,9 +58,10 @@ def _parse_assistant_message(
content: Optional[str],
function_call: Optional[FunctionCall],
tool_calls: Optional[List[ToolCall]],
custom_content: Optional[CustomContent],
) -> BaseMessage | ToolMessage:
if content is not None and function_call is None and tool_calls is None:
return AIRegularMessage(content=content)
return AIRegularMessage(content=content, custom_content=custom_content)

if content is None and function_call is not None and tool_calls is None:
return AIFunctionCallMessage(call=function_call)
Expand All @@ -67,19 +77,26 @@ def _parse_assistant_message(
)


def parse_message(msg: Message) -> BaseMessage | ToolMessage:
def parse_dial_message(msg: Message) -> BaseMessage | ToolMessage:
match msg:
case Message(role=Role.SYSTEM, content=content) if content is not None:
return SystemMessage(content=content)
case Message(role=Role.USER, content=content) if content is not None:
return HumanRegularMessage(content=content)
case Message(
role=Role.USER, content=content, custom_content=custom_content
) if content is not None:
return HumanRegularMessage(
content=content, custom_content=custom_content
)
case Message(
role=Role.ASSISTANT,
content=content,
function_call=function_call,
tool_calls=tool_calls,
custom_content=custom_content,
):
return _parse_assistant_message(content, function_call, tool_calls)
return _parse_assistant_message(
content, function_call, tool_calls, custom_content
)
case Message(
role=Role.FUNCTION, name=name, content=content
) if content is not None and name is not None:
Expand Down
12 changes: 7 additions & 5 deletions aidial_adapter_bedrock/llm/model/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
from aidial_adapter_bedrock.llm.model.ai21 import AI21Adapter
from aidial_adapter_bedrock.llm.model.amazon import AmazonAdapter
from aidial_adapter_bedrock.llm.model.anthropic import (
AnthropicAdapter,
AnthropicChat,
from aidial_adapter_bedrock.llm.model.claude.v1_v2.adapter import (
Adapter as Claude_V1_V2,
)
from aidial_adapter_bedrock.llm.model.claude.v3.adapter import (
Adapter as Claude_V3,
)
from aidial_adapter_bedrock.llm.model.cohere import CohereAdapter
from aidial_adapter_bedrock.llm.model.meta import MetaAdapter
Expand All @@ -24,13 +26,13 @@ async def get_bedrock_adapter(
| BedrockDeployment.ANTHROPIC_CLAUDE_V3_HAIKU
| BedrockDeployment.ANTHROPIC_CLAUDE_V3_OPUS
):
return AnthropicChat.create(model, region, headers)
return Claude_V3.create(model, region, headers)
case (
BedrockDeployment.ANTHROPIC_CLAUDE_INSTANT_V1
| BedrockDeployment.ANTHROPIC_CLAUDE_V2
| BedrockDeployment.ANTHROPIC_CLAUDE_V2_1
):
return await AnthropicAdapter.create(
return await Claude_V1_V2.create(
await Bedrock.acreate(region), model
)
case (
Expand Down
8 changes: 3 additions & 5 deletions aidial_adapter_bedrock/llm/model/amazon.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, AsyncIterator, Dict, List, Optional

from aidial_sdk.chat_completion import Message
from pydantic import BaseModel
from typing_extensions import override

Expand All @@ -12,7 +13,6 @@
default_partitioner,
)
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.message import BaseMessage
from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_AMAZON
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
from aidial_adapter_bedrock.llm.tools.default_emulator import (
Expand Down Expand Up @@ -117,10 +117,8 @@ def create(cls, client: Bedrock, model: str):
)

@override
def validate_base_messages(
self, messages: List[BaseMessage]
) -> List[BaseMessage]:
messages = super().validate_base_messages(messages)
def preprocess_messages(self, messages: List[Message]) -> List[Message]:
messages = super().preprocess_messages(messages)

# AWS Titan doesn't support empty messages,
# so we replace it with a single space.
Expand Down
Loading

0 comments on commit 1345bb8

Please sign in to comment.