Skip to content

Commit

Permalink
Fixes due to PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-romanov-o committed Nov 14, 2024
1 parent cf80b4a commit f051a5a
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 30 deletions.
30 changes: 15 additions & 15 deletions aidial_adapter_bedrock/llm/converse/adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Awaitable, Callable, List, Tuple, cast
from typing import Any, Awaitable, Callable, List, Tuple

from aidial_sdk.chat_completion import Message as DialMessage

Expand All @@ -13,7 +13,7 @@
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.converse.input import (
extract_converse_system_prompt,
process_messages,
to_converse_messages,
to_converse_tools,
)
from aidial_adapter_bedrock.llm.converse.output import (
Expand All @@ -33,7 +33,6 @@
DiscardedMessages,
truncate_prompt,
)
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.list import omit_by_indices
from aidial_adapter_bedrock.utils.list_projection import ListProjection

Expand Down Expand Up @@ -116,8 +115,8 @@ async def construct_converse_params(
params: ModelParameters,
) -> ConverseRequestWrapper:
system_prompt_extraction = extract_converse_system_prompt(messages)
processed_messages = await process_messages(
system_prompt_extraction.modified_messages,
processed_messages = await to_converse_messages(
system_prompt_extraction.non_system_messages,
self.storage,
start_offset=system_prompt_extraction.system_message_count,
)
Expand All @@ -128,16 +127,17 @@ async def construct_converse_params(
return ConverseRequestWrapper(
system=[system_message] if system_message else None,
messages=processed_messages,
inferenceConfig=cast(
InferenceConfig,
remove_nones(
{
"temperature": params.temperature,
"topP": params.top_p,
"maxTokens": params.max_tokens,
"stopSequences": params.stop,
}
),
inferenceConfig=InferenceConfig(
**{
key: value
for key, value in [
("temperature", params.temperature),
("topP", params.top_p),
("maxTokens", params.max_tokens),
("stopSequences", params.stop),
]
if value is not None
}
),
toolConfig=self.get_tool_config(params),
)
Expand Down
10 changes: 5 additions & 5 deletions aidial_adapter_bedrock/llm/converse/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ async def to_converse_message(
class ExtractSystemPromptResult:
system_prompt: ConverseTextPart | None
system_message_count: int
modified_messages: List[DialMessage]
non_system_messages: List[DialMessage]


def extract_converse_system_prompt(
Expand All @@ -253,7 +253,7 @@ def extract_converse_system_prompt(
system_msgs = []
found_non_system = False
system_messages_count = 0
modified_messages = []
non_system_messages = []

for msg in messages:
if msg.role == DialRole.SYSTEM:
Expand All @@ -280,16 +280,16 @@ def extract_converse_system_prompt(
assert_never(msg.content)
else:
found_non_system = True
modified_messages.append(msg)
non_system_messages.append(msg)
combined = "\n\n".join(msg for msg in system_msgs if msg)
return ExtractSystemPromptResult(
system_prompt=ConverseTextPart(text=combined) if combined else None,
system_message_count=system_messages_count,
modified_messages=modified_messages,
non_system_messages=non_system_messages,
)


async def process_messages(
async def to_converse_messages(
messages: List[DialMessage],
storage: FileStorage | None,
# Offset for system messages at the beginning
Expand Down
12 changes: 6 additions & 6 deletions aidial_adapter_bedrock/llm/converse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, TypedDict, Union
from typing import Any, Literal, Required, TypedDict, Union

from aidial_adapter_bedrock.utils.list_projection import ListProjection

Expand Down Expand Up @@ -100,14 +100,14 @@ class ConverseMessage(TypedDict):


class InferenceConfig(TypedDict, total=False):
    """Optional inference parameters for a Bedrock Converse request.

    Every key is optional (``total=False``): a key with no value must be
    omitted entirely rather than set to ``None``, which is why the caller
    filters out ``None`` values before constructing this dict.
    """

    temperature: float  # sampling temperature
    topP: float  # nucleus-sampling probability mass
    maxTokens: int  # upper bound on generated tokens
    stopSequences: list[str]  # sequences that terminate generation


class ConverseRequest(TypedDict, total=False):
    """Top-level payload for the Bedrock Converse API.

    ``messages`` is the only mandatory key (marked ``Required`` inside a
    ``total=False`` TypedDict, per PEP 655); all other keys may be
    omitted from the request.
    """

    messages: Required[list[ConverseMessage]]  # conversation turns, in order
    system: list[ConverseTextPart] | None  # optional system prompt parts
    inferenceConfig: InferenceConfig | None  # optional sampling parameters
    toolConfig: ConverseTools | None  # optional tool definitions
Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/llm/model/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from aidial_adapter_bedrock.llm.model.cohere import CohereAdapter
from aidial_adapter_bedrock.llm.model.llama.v3 import (
ConverseStreamingEmulateAdapter,
ConverseAdapterWithStreamingEmulation,
)
from aidial_adapter_bedrock.llm.model.llama.v3 import (
input_tokenizer_factory as llama_tokenizer_factory,
Expand Down Expand Up @@ -130,7 +130,7 @@ async def get_bedrock_adapter(
| ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1
| ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1
):
return ConverseStreamingEmulateAdapter(
return ConverseAdapterWithStreamingEmulation(
deployment=model,
bedrock=await Bedrock.acreate(aws_client_config),
storage=create_file_storage(api_key),
Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/llm/model/llama/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string


class ConverseStreamingEmulateAdapter(ConverseAdapter):
class ConverseAdapterWithStreamingEmulation(ConverseAdapter):
"""
Llama 3.1 supports tools, but only in non-streaming mode.
Llama 3 models support tools only in the non-streaming mode.
So we need to run request in non-streaming mode, and then emulate streaming.
"""

Expand Down
88 changes: 88 additions & 0 deletions tests/unit_tests/converse/test_converse_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,94 @@ class TestCase:
message="A system message can only follow another system message",
),
),
TestCase(
name="multiple_system_messages",
messages=[
Message(role=Role.SYSTEM, content="You are a helpful assistant."),
Message(role=Role.SYSTEM, content="You are also very friendly."),
Message(role=Role.USER, content="Hello!"),
],
params=ModelParameters(tool_config=None),
expected_output=ConverseRequestWrapper(
inferenceConfig=default_inference_config,
system=[
ConverseTextPart(
text="You are a helpful assistant.\n\nYou are also very friendly."
),
],
messages=ListProjection(
list=[
(
ConverseMessage(
role=ConverseRole.USER,
content=[ConverseTextPart(text="Hello!")],
),
{2},
)
]
),
),
),
TestCase(
name="system_message_multiple_parts",
messages=[
Message(
role=Role.SYSTEM,
content=[
MessageContentTextPart(
type="text", text="You are a helpful assistant."
),
MessageContentTextPart(
type="text", text="You are also very friendly."
),
],
),
Message(role=Role.USER, content="Hello!"),
],
params=ModelParameters(tool_config=None),
expected_output=ConverseRequestWrapper(
inferenceConfig=default_inference_config,
system=[
ConverseTextPart(
text="You are a helpful assistant.\n\nYou are also very friendly."
),
],
messages=ListProjection(
list=[
(
ConverseMessage(
role=ConverseRole.USER,
content=[ConverseTextPart(text="Hello!")],
),
{1},
)
]
),
),
),
TestCase(
name="system_message_with_forbidden_image",
messages=[
Message(
role=Role.SYSTEM,
content=[
MessageContentTextPart(
type="text", text="You are a helpful assistant."
),
MessageContentImagePart(
type="image_url",
image_url=ImageURL(url=BLUE_PNG_PICTURE.to_data_url()),
),
],
),
Message(role=Role.USER, content="Hello!"),
],
params=ModelParameters(tool_config=None),
expected_error=ExpectedException(
type=ValidationError,
message="System messages cannot contain images",
),
),
TestCase(
name="tools_convert",
messages=[
Expand Down

0 comments on commit f051a5a

Please sign in to comment.