Skip to content

Commit

Permalink
Cover with tests, update README, update poetry
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-romanov-o committed Nov 8, 2024
1 parent d7a2861 commit 791c33a
Show file tree
Hide file tree
Showing 8 changed files with 459 additions and 34 deletions.
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ Note that a model supports `/truncate_prompt` endpoint if and only if it support
|Anthropic|Claude 2.1|anthropic.claude-v2:1|text-to-text||||
|Anthropic|Claude 2|anthropic.claude-v2|text-to-text||||
|Anthropic|Claude Instant 1.2|anthropic.claude-instant-v1|text-to-text|🟡|🟡||
|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.1 8B Instruct|meta.llama3-1-8b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3 Chat 70B Instruct|meta.llama3-70b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3 Chat 8B Instruct|meta.llama3-8b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 2 Chat 70B|meta.llama2-70b-chat-v1|text-to-text|🟡|🟡||
|Meta|Llama 2 Chat 13B|meta.llama2-13b-chat-v1|text-to-text|🟡|🟡||
|Meta|Llama 3 Chat 70B Instruct|meta.llama3-70b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.1 8B Instruct|meta.llama3-1-8b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.2 1B Instruct|us.meta.llama3-2-1b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.2 3B Instruct|us.meta.llama3-2-3b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.2 11B Instruct|us.meta.llama3-2-11b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.2 90B Instruct|us.meta.llama3-2-90b-instruct-v1:0|text-to-text|🟡|🟡||
|Stability AI|SDXL 1.0|stability.stable-diffusion-xl-v1|text-to-image||🟡||
|Stability AI|SD3 Large 1.0|stability.sd3-large-v1:0|text-to-image / image-to-image||🟡||
|Stability AI|Stable Image Ultra 1.0|stability.stable-image-ultra-v1:0|text-to-image||🟡||
Expand Down
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ class ChatCompletionDeployment(str, Enum):

META_LLAMA3_8B_INSTRUCT_V1 = "meta.llama3-8b-instruct-v1:0"
META_LLAMA3_70B_INSTRUCT_V1 = "meta.llama3-70b-instruct-v1:0"
META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0"
META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0"
META_LLAMA3_1_8B_INSTRUCT_V1 = "meta.llama3-1-8b-instruct-v1:0"
META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0"
META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0"
META_LLAMA3_2_1B_INSTRUCT_V1 = "us.meta.llama3-2-1b-instruct-v1:0"
META_LLAMA3_2_3B_INSTRUCT_V1 = "us.meta.llama3-2-3b-instruct-v1:0"
META_LLAMA3_2_11B_INSTRUCT_V1 = "us.meta.llama3-2-11b-instruct-v1:0"
Expand Down
6 changes: 5 additions & 1 deletion aidial_adapter_bedrock/llm/converse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ class ConverseToolConfig(TypedDict):
inputSchema: dict


class ConverseToolSpec(TypedDict):
    """Single entry of the Converse `toolConfig.tools` list.

    Wraps one tool definition under the ``"toolSpec"`` key, matching the
    shape the Bedrock Converse API expects for each tool entry.
    """

    # The actual tool definition (name, description, inputSchema).
    toolSpec: ConverseToolConfig


class ConverseTools(TypedDict):
    """Tool configuration payload (`toolConfig`) for a Converse request.

    Defect fixed: the class body carried two annotations for ``tools``
    (``list[ConverseToolConfig]`` immediately superseded by
    ``list[ConverseToolSpec]``). The first was dead code left over from an
    earlier revision; only the wrapped ``ConverseToolSpec`` form matches the
    Bedrock Converse API, so the stale annotation is removed.
    """

    # Each entry is {"toolSpec": {...}} as required by the Converse API.
    tools: list[ConverseToolSpec]
    # Tool-choice selector, e.g. {"any": {}} or {"auto": {}}.
    toolChoice: dict


Expand Down
39 changes: 20 additions & 19 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ repository = "https://github.com/epam/ai-dial-adapter-bedrock/"

[tool.poetry.dependencies]
python = "^3.11,<4.0"
boto3 = "1.28.57"
botocore = "1.31.57"
boto3 = "1.35.41"
botocore = "1.35.41"
aidial-sdk = {version = "0.14.0", extras = ["telemetry"]}
anthropic = {version = "0.28.1", extras = ["bedrock"]}
fastapi = "0.115.2"
Expand Down
232 changes: 232 additions & 0 deletions tests/unit_tests/converse/test_converse_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
from dataclasses import dataclass
from typing import List

import pytest
from aidial_sdk.chat_completion.request import (
Function,
FunctionCall,
Message,
Role,
ToolCall,
)

from aidial_adapter_bedrock.aws_client_config import AWSClientConfig
from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.llm.converse.adapter import ConverseAdapter
from aidial_adapter_bedrock.llm.converse.types import (
ConverseMessage,
ConverseParams,
ConverseRole,
ConverseTextPart,
ConverseToolResultPart,
ConverseToolUseConfig,
ConverseToolUsePart,
InferenceConfig,
)
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig
from aidial_adapter_bedrock.utils.list_projection import ListProjection


async def _input_tokenizer_factory(_deployment, _params):
async def _test_tokenizer(_messages) -> int:
return 100

return _test_tokenizer


@dataclass
class TestCase:
    """One table-driven scenario for `test_converse_adapter`.

    Exactly one of `expected_output` / `expected_error` should be set:
    either the adapter produces `expected_output`, or constructing the
    Converse params raises `expected_error`.
    """

    # Tell pytest not to collect this helper dataclass as a test class.
    __test__ = False
    # Identifier used as the parametrized test id.
    name: str
    # DIAL chat messages fed to the adapter.
    messages: List[Message]
    # Request parameters (tool config etc.) for the adapter.
    params: ModelParameters
    expected_output: ConverseParams | None = None
    expected_error: type[Exception] | None = None


# Shared inference config: no stop sequences; other settings left unset.
default_inference_config = InferenceConfig(stopSequences=[])

# Table of scenarios for ConverseAdapter.construct_converse_params.
# NOTE(review): ListProjection pairs each ConverseMessage with the set of
# indices of the source DIAL messages it was derived from — inferred from
# its usage here; confirm against ListProjection's definition.
TEST_CASES = [
    # A single user message maps to one USER Converse message (index 0).
    TestCase(
        name="plain_message",
        messages=[Message(role=Role.USER, content="Hello, world!")],
        params=ModelParameters(tool_config=None),
        expected_output=ConverseParams(
            inferenceConfig=default_inference_config,
            messages=ListProjection(
                list=[
                    (
                        ConverseMessage(
                            role=ConverseRole.USER,
                            content=[ConverseTextPart(text="Hello, world!")],
                        ),
                        {0},
                    )
                ]
            ),
        ),
    ),
    # A leading system message is lifted into the top-level `system` field;
    # only the user message remains in `messages` (source index 1).
    TestCase(
        name="system_message",
        messages=[
            Message(role=Role.SYSTEM, content="You are a helpful assistant."),
            Message(role=Role.USER, content="Hello!"),
        ],
        params=ModelParameters(tool_config=None),
        expected_output=ConverseParams(
            inferenceConfig=default_inference_config,
            system=[ConverseTextPart(text="You are a helpful assistant.")],
            messages=ListProjection(
                list=[
                    (
                        ConverseMessage(
                            role=ConverseRole.USER,
                            content=[ConverseTextPart(text="Hello!")],
                        ),
                        {1},
                    )
                ]
            ),
        ),
    ),
    # A system message appearing after a user message is rejected.
    TestCase(
        name="system_message_after_user",
        messages=[
            Message(role=Role.SYSTEM, content="You are a helpful assistant."),
            Message(role=Role.USER, content="Hello!"),
            Message(role=Role.SYSTEM, content="You are a helpful assistant."),
        ],
        params=ModelParameters(tool_config=None),
        expected_error=ValidationError,
    ),
    # Full tool round-trip: DIAL tool definitions become Converse toolSpec
    # entries, assistant tool_calls become toolUse parts, and the TOOL
    # message becomes a USER message with a toolResult part.
    TestCase(
        name="tools_convert",
        messages=[
            Message(role=Role.USER, content="What's the weather?"),
            Message(
                role=Role.ASSISTANT,
                content=None,
                tool_calls=[
                    ToolCall(
                        index=0,
                        id="call_123",
                        type="function",
                        function=FunctionCall(
                            name="get_weather",
                            arguments='{"location": "London"}',
                        ),
                    )
                ],
            ),
            Message(
                role=Role.TOOL,
                content='{"temperature": "20C"}',
                tool_call_id="call_123",
            ),
        ],
        params=ModelParameters(
            tool_config=ToolsConfig(
                functions=[
                    Function(
                        name="get_weather",
                        description="Get the weather",
                        parameters={"type": "object", "properties": {}},
                    )
                ],
                # required=True is expected to surface as toolChoice {"any": {}}.
                required=True,
                tool_ids=None,
            )
        ),
        expected_output=ConverseParams(
            inferenceConfig=default_inference_config,
            toolConfig={
                "tools": [
                    {
                        "toolSpec": {
                            "name": "get_weather",
                            "description": "Get the weather",
                            "inputSchema": {
                                "json": {"properties": {}, "type": "object"}
                            },
                        }
                    }
                ],
                "toolChoice": {"any": {}},
            },
            messages=ListProjection(
                list=[
                    (
                        ConverseMessage(
                            role=ConverseRole.USER,
                            content=[
                                ConverseTextPart(text="What's the weather?")
                            ],
                        ),
                        {0},
                    ),
                    (
                        # JSON-string arguments are parsed into a dict input.
                        ConverseMessage(
                            role=ConverseRole.ASSISTANT,
                            content=[
                                ConverseToolUsePart(
                                    toolUse=ConverseToolUseConfig(
                                        toolUseId="call_123",
                                        name="get_weather",
                                        input={"location": "London"},
                                    )
                                )
                            ],
                        ),
                        {1},
                    ),
                    (
                        # Tool results travel back in a USER-role message.
                        ConverseMessage(
                            role=ConverseRole.USER,
                            content=[
                                ConverseToolResultPart(
                                    toolResult={
                                        "toolUseId": "call_123",
                                        "content": [
                                            {"json": {"temperature": "20C"}}
                                        ],
                                        "status": "success",
                                    }
                                )
                            ],
                        ),
                        {2},
                    ),
                ]
            ),
        ),
    ),
]


@pytest.mark.parametrize(
    "test_case", TEST_CASES, ids=lambda tc: tc.name
)
@pytest.mark.asyncio
async def test_converse_adapter(
    test_case: TestCase,
):
    """Run one TEST_CASES scenario through construct_converse_params.

    Asserts either the exact ConverseParams output or that the declared
    exception type is raised while constructing the request.
    """
    # NOTE(review): Bedrock.create(...) is not awaited here — assumes a
    # synchronous factory; confirm against Bedrock's definition.
    adapter = ConverseAdapter(
        deployment="test",
        bedrock=Bedrock.create(AWSClientConfig(region="us-east-1")),
        tokenize_text=len,
        input_tokenizer_factory=_input_tokenizer_factory,  # type: ignore
        support_tools=True,
        storage=None,
    )
    construct_coro = adapter.construct_converse_params(
        messages=test_case.messages,
        params=test_case.params,
    )

    if test_case.expected_error is None:
        assert await construct_coro == test_case.expected_output
    else:
        with pytest.raises(test_case.expected_error):
            await construct_coro
Loading

0 comments on commit 791c33a

Please sign in to comment.