From 74723a7d44391c211eefeb5be01c19616892acec Mon Sep 17 00:00:00 2001 From: Roman Romanov Date: Thu, 14 Nov 2024 13:39:44 +0200 Subject: [PATCH] Fix converse params, add one more integration tests --- .../llm/converse/adapter.py | 19 +++++++++---------- aidial_adapter_bedrock/llm/converse/types.py | 17 +++++++++++------ .../integration_tests/test_chat_completion.py | 14 ++++++++++++++ 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/aidial_adapter_bedrock/llm/converse/adapter.py b/aidial_adapter_bedrock/llm/converse/adapter.py index f650391..41a9da1 100644 --- a/aidial_adapter_bedrock/llm/converse/adapter.py +++ b/aidial_adapter_bedrock/llm/converse/adapter.py @@ -33,6 +33,7 @@ DiscardedMessages, truncate_prompt, ) +from aidial_adapter_bedrock.utils.json import remove_nones from aidial_adapter_bedrock.utils.list import omit_by_indices from aidial_adapter_bedrock.utils.list_projection import ListProjection @@ -128,16 +129,14 @@ async def construct_converse_params( system=[system_message] if system_message else None, messages=processed_messages, inferenceConfig=InferenceConfig( - **{ - key: value - for key, value in [ - ("temperature", params.temperature), - ("topP", params.top_p), - ("maxTokens", params.max_tokens), - ("stopSequences", params.stop), - ] - if value is not None - } + **remove_nones( + { + "temperature": params.temperature, + "topP": params.top_p, + "maxTokens": params.max_tokens, + "stopSequences": params.stop, + } + ) ), toolConfig=self.get_tool_config(params), ) diff --git a/aidial_adapter_bedrock/llm/converse/types.py b/aidial_adapter_bedrock/llm/converse/types.py index 1b3a326..54b8e15 100644 --- a/aidial_adapter_bedrock/llm/converse/types.py +++ b/aidial_adapter_bedrock/llm/converse/types.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Literal, Required, TypedDict, Union +from aidial_adapter_bedrock.utils.json import remove_nones from aidial_adapter_bedrock.utils.list_projection import ListProjection @@ -108,9 +109,9 @@ class InferenceConfig(TypedDict, total=False): class ConverseRequest(TypedDict, total=False): messages: Required[list[ConverseMessage]] - system: list[ConverseTextPart] | None - inferenceConfig: InferenceConfig | None - toolConfig: ConverseTools | None + system: list[ConverseTextPart] + inferenceConfig: InferenceConfig + toolConfig: ConverseTools @dataclass @@ -122,10 +123,14 @@ class ConverseRequestWrapper: def to_request(self) -> ConverseRequest: return ConverseRequest( - system=self.system, messages=self.messages.raw_list, - inferenceConfig=self.inferenceConfig, - toolConfig=self.toolConfig, + **remove_nones( + { + "inferenceConfig": self.inferenceConfig, + "toolConfig": self.toolConfig, + "system": self.system, + } + ), ) diff --git a/tests/integration_tests/test_chat_completion.py b/tests/integration_tests/test_chat_completion.py index 74922a7..0bff30a 100644 --- a/tests/integration_tests/test_chat_completion.py +++ b/tests/integration_tests/test_chat_completion.py @@ -406,6 +406,19 @@ def dial_recall_expected(r: ChatCompletionResult): ) if is_llama3(deployment): + + test_case( + name="out_of_turn", + messages=[ai("hello"), user("what's 7+5?")], + expected=streaming_error( + ExpectedException( + type=BadRequestError, + message="A conversation must start with a user message", + status_code=400, + ) + ), + ) + test_case( name="many system", messages=[ @@ -439,6 +452,7 @@ def dial_recall_expected(r: ChatCompletionResult): ai("5"), user(query), ] + # Llama 3 works badly with system messages along tools if not is_llama3(deployment): init_messages.insert(0, sys("act as a helpful assistant"))