Commit
fix: support multiple tool call messages for Claude (#165)
adubovik authored Oct 25, 2024
1 parent cab277e commit d19bf51
Showing 11 changed files with 486 additions and 170 deletions.
37 changes: 20 additions & 17 deletions aidial_adapter_bedrock/dial_api/request.py
@@ -69,23 +69,26 @@ def tools_mode(self) -> ToolsMode | None:
def collect_text_content(
content: MessageContentSpecialized, delimiter: str = "\n\n"
) -> str:

if content is None:
return ""

if isinstance(content, str):
return content

texts: List[str] = []
for part in content:
if isinstance(part, MessageContentTextPart):
texts.append(part.text)
else:
raise ValidationError(
"Can't extract text from a multi-modal content part"
)

return delimiter.join(texts)
match content:
case None:
return ""
case str():
return content
case list():
texts: List[str] = []
for part in content:
match part:
case MessageContentTextPart(text=text):
texts.append(text)
case MessageContentImagePart():
raise ValidationError(
"Can't extract text from an image content part"
)
case _:
assert_never(part)
return delimiter.join(texts)
case _:
assert_never(content)


def to_message_content(content: MessageContentSpecialized) -> MessageContent:
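For reference, the rewritten helper reads as follows when extracted into a standalone sketch (TextPart/ImagePart are stand-in dataclasses for the adapter's message-part types, and ValueError stands in for its ValidationError; assert_never requires Python 3.11+):

from dataclasses import dataclass
from typing import List, Union, assert_never

@dataclass
class TextPart:
    text: str

@dataclass
class ImagePart:
    url: str

Content = Union[None, str, List[Union[TextPart, ImagePart]]]

def collect_text_content(content: Content, delimiter: str = "\n\n") -> str:
    match content:
        case None:
            return ""
        case str():
            return content
        case list():
            texts: List[str] = []
            for part in content:
                match part:
                    case TextPart(text=text):
                        texts.append(text)
                    case ImagePart():
                        raise ValueError(
                            "Can't extract text from an image content part"
                        )
                    case _:
                        assert_never(part)
            return delimiter.join(texts)
        case _:
            assert_never(content)

# Expected behavior:
assert collect_text_content(None) == ""
assert collect_text_content("hello") == "hello"
assert collect_text_content([TextPart("a"), TextPart("b")]) == "a\n\nb"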
9 changes: 9 additions & 0 deletions aidial_adapter_bedrock/llm/consumer.py
@@ -58,6 +58,11 @@ def create_function_tool_call(self, tool_call: ToolCall):
def create_function_call(self, function_call: FunctionCall):
pass

@property
@abstractmethod
def has_function_call(self) -> bool:
pass


class ChoiceConsumer(Consumer):
usage: TokenUsage
@@ -134,3 +139,7 @@ def create_function_call(self, function_call: FunctionCall):
self.choice.create_function_call(
name=function_call.name, arguments=function_call.arguments
)

@property
def has_function_call(self) -> bool:
return self.choice.has_function_call
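One plausible consumer of the new property (a hypothetical sketch, not code from this commit): the legacy functions API allows at most one call per response, so an implementation can consult has_function_call before emitting a second one.

# Hypothetical guard built on the new property; names are illustrative only.
class OneCallChoice:
    def __init__(self) -> None:
        self._emitted = False

    @property
    def has_function_call(self) -> bool:
        return self._emitted

    def create_function_call(self, name: str, arguments: str) -> None:
        # The functions mode supports a single call per response,
        # so a second call is rejected outright.
        if self._emitted:
            raise ValueError("at most one function call is supported per response")
        self._emitted = True
        print(f"function call: {name}({arguments})")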
105 changes: 63 additions & 42 deletions aidial_adapter_bedrock/llm/model/claude/v3/adapter.py
@@ -7,13 +7,13 @@
from anthropic.lib.bedrock import AsyncAnthropicBedrock
from anthropic.lib.streaming import (
AsyncMessageStream,
ContentBlockStopEvent,
InputJsonEvent,
TextEvent,
)
from anthropic.types import (
ContentBlockDeltaEvent,
ContentBlockStartEvent,
ContentBlockStopEvent,
MessageDeltaEvent,
)
from anthropic.types import MessageParam as ClaudeMessage
@@ -65,6 +65,7 @@
truncate_prompt,
)
from aidial_adapter_bedrock.utils.json import json_dumps_short
from aidial_adapter_bedrock.utils.list_projection import ListProjection
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


@@ -89,7 +90,7 @@ async def on_stream_event(self, event: MessageStreamEvent):
@dataclass
class ClaudeRequest:
params: ClaudeParameters
messages: List[ClaudeMessage]
messages: ListProjection[ClaudeMessage]


class Adapter(ChatCompletionAdapter):
@@ -105,17 +106,20 @@ async def _prepare_claude_request(

tools = NOT_GIVEN
tool_choice: ToolChoice | NotGiven = NOT_GIVEN
if params.tool_config is not None:
if (tool_config := params.tool_config) is not None:
tools = [
to_claude_tool_config(tool_function)
for tool_function in params.tool_config.functions
for tool_function in tool_config.functions
]

tool_choice = (
{"type": "any"}
if params.tool_config.required
else {"type": "auto"}
{"type": "any"} if tool_config.required else {"type": "auto"}
)

# NOTE tool_choice.disable_parallel_tool_use=True option isn't supported
# by older Claude3 versions, so we limit the number of generated function calls
# to one in the adapter itself for the functions mode.

parsed_messages = [
process_with_tools(parse_dial_message(m), params.tools_mode)
for m in messages
@@ -142,27 +146,29 @@
return ClaudeRequest(params=claude_params, messages=claude_messages)

async def _compute_discarded_messages(
self,
request: ClaudeRequest,
max_prompt_tokens: int | None,
self, request: ClaudeRequest, max_prompt_tokens: int | None
) -> Tuple[DiscardedMessages | None, ClaudeRequest]:
if max_prompt_tokens is None:
return None, request

discarded_messages, messages = await truncate_prompt(
messages=request.messages,
messages=request.messages.list,
tokenizer=create_tokenizer(self.deployment, request.params),
keep_message=keep_last,
partitioner=turn_based_partitioner,
model_limit=None,
user_limit=max_prompt_tokens,
)

if request.params["system"] is not NOT_GIVEN:
discarded_messages = [idx + 1 for idx in discarded_messages]
claude_messages = ListProjection(messages)

if max_prompt_tokens is None:
discarded_messages = None
discarded_messages = list(
request.messages.to_original_indices(discarded_messages)
)

return discarded_messages, ClaudeRequest(
params=request.params, messages=messages
params=request.params,
messages=claude_messages,
)

async def chat(
@@ -197,7 +203,7 @@ async def count_prompt_tokens(
) -> int:
request = await self._prepare_claude_request(params, messages)
return await create_tokenizer(self.deployment, request.params)(
request.messages
request.messages.list
)

async def count_completion_tokens(self, string: str) -> int:
@@ -230,39 +236,48 @@ async def invoke_streaming(
log.debug(f"Streaming request: {msg}")

async with self.client.messages.stream(
messages=request.messages,
messages=request.messages.raw_list,
model=self.deployment.model_id,
**request.params,
) as stream:
prompt_tokens = 0
completion_tokens = 0
stop_reason = None
async for event in stream:
if log.isEnabledFor(DEBUG):
log.debug(
f"claude response event: {json_dumps_short(event)}"
)

match event:
case MessageStartEvent():
prompt_tokens += event.message.usage.input_tokens
case TextEvent():
consumer.append_content(event.text)
case MessageDeltaEvent():
completion_tokens += event.usage.output_tokens
case ContentBlockStopEvent():
if isinstance(event.content_block, ToolUseBlock):
process_tools_block(
consumer, event.content_block, tools_mode
)
case MessageStopEvent():
completion_tokens += event.message.usage.output_tokens
stop_reason = event.message.stop_reason
case MessageStartEvent(message=message):
prompt_tokens += message.usage.input_tokens
case TextEvent(text=text):
consumer.append_content(text)
case MessageDeltaEvent(usage=usage):
completion_tokens += usage.output_tokens
case ContentBlockStopEvent(content_block=content_block):
match content_block:
case ToolUseBlock():
process_tools_block(
consumer, content_block, tools_mode
)
case TextBlock():
# Already handled in TextEvent
pass
case _:
assert_never(content_block)
case MessageStopEvent(message=message):
completion_tokens += message.usage.output_tokens
stop_reason = message.stop_reason
case (
InputJsonEvent()
| ContentBlockStartEvent()
| ContentBlockDeltaEvent()
):
pass
case _:
raise ValueError(
f"Unsupported event type! {type(event)}"
)
assert_never(event)

consumer.close_content(
to_dial_finish_reason(stop_reason, tools_mode)
@@ -295,18 +310,24 @@ async def invoke_non_streaming(
log.debug(f"Request: {msg}")

message = await self.client.messages.create(
messages=request.messages,
messages=request.messages.raw_list,
model=self.deployment.model_id,
**request.params,
stream=False,
)

if log.isEnabledFor(DEBUG):
log.debug(f"claude response message: {json_dumps_short(message)}")

for content in message.content:
if isinstance(content, TextBlock):
consumer.append_content(content.text)
elif isinstance(content, ToolUseBlock):
process_tools_block(consumer, content, tools_mode)
else:
assert_never(content)
match content:
case TextBlock(text=text):
consumer.append_content(text)
case ToolUseBlock():
process_tools_block(consumer, content, tools_mode)
case _:
assert_never(content)

consumer.close_content(
to_dial_finish_reason(message.stop_reason, tools_mode)
)
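The ListProjection utility added in aidial_adapter_bedrock/utils/list_projection.py is among the changed files but isn't shown on this page. Inferred from its call sites above (an inference, not the actual implementation), a minimal sketch might look like this: each element remembers the set of original DIAL message indices it was built from, so an index into the merged or truncated Claude message list can be mapped back to the messages the caller sent.

from dataclasses import dataclass, field
from typing import Generic, Iterable, Iterator, List, Set, Tuple, TypeVar

T = TypeVar("T")

@dataclass
class ListProjection(Generic[T]):
    # Each entry pairs an item with the original indices it came from.
    list: List[Tuple[T, Set[int]]] = field(default_factory=list)

    @property
    def raw_list(self) -> List[T]:
        # The items alone, e.g. for passing to the Anthropic client.
        return [item for item, _ in self.list]

    def append(self, item: T, idx: int) -> None:
        self.list.append((item, {idx}))

    def to_original_indices(self, indices: Iterable[int]) -> Iterator[int]:
        # Map indices in the projected list back to original indices,
        # e.g. to report which DIAL messages were discarded by truncation.
        for idx in indices:
            yield from sorted(self.list[idx][1])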
71 changes: 55 additions & 16 deletions aidial_adapter_bedrock/llm/model/claude/v3/converters.py
@@ -1,5 +1,5 @@
import json
from typing import List, Literal, Optional, Tuple, assert_never, cast
from typing import List, Literal, Optional, Set, Tuple, assert_never, cast

from aidial_sdk.chat_completion import (
FinishReason,
@@ -35,6 +35,8 @@
SystemMessage,
)
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode
from aidial_adapter_bedrock.utils.list import group_by
from aidial_adapter_bedrock.utils.list_projection import ListProjection
from aidial_adapter_bedrock.utils.resource import Resource

ClaudeFinishReason = Literal[
@@ -142,34 +144,69 @@ def _to_claude_tool_result(
)


def _merge_messages_with_same_role(
messages: ListProjection[MessageParam],
) -> ListProjection[MessageParam]:
def _key(message: Tuple[MessageParam, Set[int]]) -> str:
return message[0]["role"]

def _merge(
a: Tuple[MessageParam, Set[int]],
b: Tuple[MessageParam, Set[int]],
) -> Tuple[MessageParam, Set[int]]:
(msg1, set1), (msg2, set2) = a, b

content1 = msg1["content"]
content2 = msg2["content"]

if isinstance(content1, str):
content1 = [
cast(TextBlockParam, {"type": "text", "text": content1})
]

if isinstance(content2, str):
content2 = [
cast(TextBlockParam, {"type": "text", "text": content2})
]

return {
"role": msg1["role"],
"content": list(content1) + list(content2),
}, set1 | set2

return ListProjection(group_by(messages.list, _key, lambda x: x, _merge))


async def to_claude_messages(
messages: List[BaseMessage | HumanToolResultMessage | AIToolCallMessage],
file_storage: Optional[FileStorage],
) -> Tuple[Optional[str], List[MessageParam]]:
if not messages:
return None, []
) -> Tuple[Optional[str], ListProjection[MessageParam]]:

system_prompt: str | None = None
if isinstance(messages[0], SystemMessage):
if messages and isinstance(messages[0], SystemMessage):
system_prompt = messages[0].text_content
messages = messages[1:]

claude_messages: List[MessageParam] = []
for message in messages:
idx_offset = int(system_prompt is not None)

ret: ListProjection[MessageParam] = ListProjection()
for idx, message in enumerate(messages, start=idx_offset):
match message:
case HumanRegularMessage():
claude_messages.append(
ret.append(
MessageParam(
role="user",
content=await _to_claude_message(file_storage, message),
)
),
idx,
)
case AIRegularMessage():
claude_messages.append(
ret.append(
MessageParam(
role="assistant",
content=await _to_claude_message(file_storage, message),
)
),
idx,
)
case AIToolCallMessage():
content: List[TextBlockParam | ToolUseBlockParam] = [
@@ -178,18 +215,20 @@ async def to_claude_messages(
if message.content is not None:
content.insert(0, _create_text_block(message.content))

claude_messages.append(
ret.append(
MessageParam(
role="assistant",
content=content,
)
),
idx,
)
case HumanToolResultMessage():
claude_messages.append(
ret.append(
MessageParam(
role="user",
content=[_to_claude_tool_result(message)],
)
),
idx,
)
case SystemMessage():
raise ValidationError(
Expand All @@ -198,7 +237,7 @@ async def to_claude_messages(
case _:
assert_never(message)

return system_prompt, claude_messages
return system_prompt, _merge_messages_with_same_role(ret)


def to_dial_finish_reason(
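Claude's Messages API rejects consecutive messages with the same role, which is why a request containing several tool-call messages in a row previously failed; _merge_messages_with_same_role folds such runs into a single turn while unioning the original indices. Below is a self-contained sketch of that behavior, with plain dicts standing in for anthropic's MessageParam and a plausible group_by (the real helper lives in aidial_adapter_bedrock/utils/list.py and isn't shown in this diff):

from typing import Callable, List, Set, Tuple, TypeVar

T = TypeVar("T")
K = TypeVar("K")

def group_by(
    xs: List[T],
    key: Callable[[T], K],
    init: Callable[[T], T],
    merge: Callable[[T, T], T],
) -> List[T]:
    # Fold runs of adjacent elements that share a key (assumed semantics).
    out: List[T] = []
    for x in xs:
        if out and key(out[-1]) == key(x):
            out[-1] = merge(out[-1], init(x))
        else:
            out.append(init(x))
    return out

Entry = Tuple[dict, Set[int]]  # (Claude message, original DIAL indices)

def merge_entries(a: Entry, b: Entry) -> Entry:
    (m1, s1), (m2, s2) = a, b
    as_blocks = lambda c: c if isinstance(c, list) else [{"type": "text", "text": c}]
    return (
        {"role": m1["role"], "content": as_blocks(m1["content"]) + as_blocks(m2["content"])},
        s1 | s2,
    )

entries: List[Entry] = [
    ({"role": "user", "content": "Weather in Paris and Rome?"}, {0}),
    ({"role": "assistant", "content": [{"type": "tool_use", "id": "call_1"}]}, {1}),
    ({"role": "assistant", "content": [{"type": "tool_use", "id": "call_2"}]}, {2}),
]
merged = group_by(entries, lambda e: e[0]["role"], lambda e: e, merge_entries)
assert len(merged) == 2                   # user turn + ONE assistant turn
assert merged[1][1] == {1, 2}             # built from two DIAL tool-call messages
assert len(merged[1][0]["content"]) == 2  # both tool_use blocks in one turn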