Skip to content

Commit

Permalink
Fix converse params, add one more integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-romanov-o committed Nov 14, 2024
1 parent f051a5a commit 74723a7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
19 changes: 9 additions & 10 deletions aidial_adapter_bedrock/llm/converse/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
DiscardedMessages,
truncate_prompt,
)
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.list import omit_by_indices
from aidial_adapter_bedrock.utils.list_projection import ListProjection

Expand Down Expand Up @@ -128,16 +129,14 @@ async def construct_converse_params(
system=[system_message] if system_message else None,
messages=processed_messages,
inferenceConfig=InferenceConfig(
**{
key: value
for key, value in [
("temperature", params.temperature),
("topP", params.top_p),
("maxTokens", params.max_tokens),
("stopSequences", params.stop),
]
if value is not None
}
**remove_nones(
{
"temperature": params.temperature,
"topP": params.top_p,
"maxTokens": params.max_tokens,
"stopSequences": params.stop,
}
)
),
toolConfig=self.get_tool_config(params),
)
Expand Down
17 changes: 11 additions & 6 deletions aidial_adapter_bedrock/llm/converse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum
from typing import Any, Literal, Required, TypedDict, Union

from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.list_projection import ListProjection


Expand Down Expand Up @@ -108,9 +109,9 @@ class InferenceConfig(TypedDict, total=False):

class ConverseRequest(TypedDict, total=False):
messages: Required[list[ConverseMessage]]
system: list[ConverseTextPart] | None
inferenceConfig: InferenceConfig | None
toolConfig: ConverseTools | None
system: list[ConverseTextPart]
inferenceConfig: InferenceConfig
toolConfig: ConverseTools


@dataclass
Expand All @@ -122,10 +123,14 @@ class ConverseRequestWrapper:

def to_request(self) -> ConverseRequest:
return ConverseRequest(
system=self.system,
messages=self.messages.raw_list,
inferenceConfig=self.inferenceConfig,
toolConfig=self.toolConfig,
**remove_nones(
{
"inferenceConfig": self.inferenceConfig,
"toolConfig": self.toolConfig,
"system": self.system,
}
),
)


Expand Down
14 changes: 14 additions & 0 deletions tests/integration_tests/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,19 @@ def dial_recall_expected(r: ChatCompletionResult):
)

if is_llama3(deployment):

test_case(
name="out_of_turn",
messages=[ai("hello"), user("what's 7+5?")],
expected=streaming_error(
ExpectedException(
type=BadRequestError,
message="A conversation must start with a user message",
status_code=400,
)
),
)

test_case(
name="many system",
messages=[
Expand Down Expand Up @@ -439,6 +452,7 @@ def dial_recall_expected(r: ChatCompletionResult):
ai("5"),
user(query),
]
            # Llama 3 works badly with system messages alongside tools
if not is_llama3(deployment):
init_messages.insert(0, sys("act as a helpful assistant"))

Expand Down

0 comments on commit 74723a7

Please sign in to comment.