Skip to content

Commit

Permalink
feat: Add Nova micro/lite/pro using Converse (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-romanov-o authored Dec 6, 2024
1 parent ec1b94a commit c061d9c
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 56 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ Note that a model supports `/truncate_prompt` endpoint if and only if it support
|Stability AI|Stable Image Ultra 1.0|stability.stable-image-ultra-v1:0|text-to-image||🟡||
|Stability AI|Stable Image Core 1.0|stability.stable-image-core-v1:0|text-to-image||🟡||
|Amazon|Titan Text G1 - Express|amazon.titan-tg1-large|text-to-text|🟡|🟡||
|Amazon|Nova Pro|amazon.nova-pro-v1:0|text-to-text, image-to-text|🟡|🟡||
|Amazon|Nova Lite|amazon.nova-lite-v1:0|text-to-text, image-to-text|🟡|🟡||
|Amazon|Nova Micro|amazon.nova-micro-v1:0|text-to-text|🟡|🟡||
|AI21 Labs|Jurassic-2 Ultra|ai21.j2-jumbo-instruct|text-to-text|🟡|🟡||
|AI21 Labs|Jurassic-2 Ultra v1|ai21.j2-ultra-v1|text-to-text|🟡|🟡||
|AI21 Labs|Jurassic-2 Mid|ai21.j2-grande-instruct|text-to-text|🟡|🟡||
Expand Down
3 changes: 3 additions & 0 deletions aidial_adapter_bedrock/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@


class ChatCompletionDeployment(str, Enum):
AMAZON_NOVA_PRO = "amazon.nova-pro-v1:0"
AMAZON_NOVA_LITE = "amazon.nova-lite-v1:0"
AMAZON_NOVA_MICRO = "amazon.nova-micro-v1:0"
AMAZON_TITAN_TG1_LARGE = "amazon.titan-tg1-large"

AI21_J2_GRANDE_INSTRUCT = "ai21.j2-grande-instruct"
Expand Down
24 changes: 24 additions & 0 deletions aidial_adapter_bedrock/llm/converse/default_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json

from aidial_adapter_bedrock.llm.converse.adapter import ConverseMessages
from aidial_adapter_bedrock.llm.converse.types import (
ConverseDeployment,
ConverseRequestWrapper,
)
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string


def default_converse_tokenizer_factory(
    deployment: ConverseDeployment, params: ConverseRequestWrapper
):
    """Build an async tokenizer estimating token usage for Converse messages.

    The fixed cost of the request's tool configuration and system prompt is
    computed once up front; the returned coroutine adds the per-message cost
    (each message serialized to JSON and tokenized) on top of that baseline.
    """
    # Tool config and system prompt are constant per request — tokenize once.
    fixed_tokens = default_tokenize_string(
        json.dumps(params.toolConfig)
    ) + default_tokenize_string(json.dumps(params.system))

    async def tokenizer(msg_items: ConverseMessages) -> int:
        message_tokens = 0
        for item in msg_items:
            # item[0] is the message payload; item's remaining fields are
            # not counted, matching the original estimate.
            message_tokens += default_tokenize_string(json.dumps(item[0]))
        return message_tokens + fixed_tokens

    return tokenizer
22 changes: 17 additions & 5 deletions aidial_adapter_bedrock/llm/model/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
)
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
from aidial_adapter_bedrock.llm.converse.adapter import ConverseAdapter
from aidial_adapter_bedrock.llm.converse.default_tokenizer import (
default_converse_tokenizer_factory,
)
from aidial_adapter_bedrock.llm.model.ai21 import AI21Adapter
from aidial_adapter_bedrock.llm.model.amazon import AmazonAdapter
from aidial_adapter_bedrock.llm.model.claude.v1_v2.adapter import (
Expand All @@ -33,9 +36,6 @@
from aidial_adapter_bedrock.llm.model.llama.v3 import (
ConverseAdapterWithStreamingEmulation,
)
from aidial_adapter_bedrock.llm.model.llama.v3 import (
input_tokenizer_factory as llama_tokenizer_factory,
)
from aidial_adapter_bedrock.llm.model.stability.v1 import StabilityV1Adapter
from aidial_adapter_bedrock.llm.model.stability.v2 import StabilityV2Adapter

Expand Down Expand Up @@ -112,6 +112,18 @@ async def get_bedrock_adapter(
return AmazonAdapter.create(
await Bedrock.acreate(aws_client_config), model
)
case (
ChatCompletionDeployment.AMAZON_NOVA_MICRO
| ChatCompletionDeployment.AMAZON_NOVA_PRO
| ChatCompletionDeployment.AMAZON_NOVA_LITE
):
return ConverseAdapter(
deployment=model,
bedrock=await Bedrock.acreate(aws_client_config),
storage=create_file_storage(api_key),
input_tokenizer_factory=default_converse_tokenizer_factory,
support_tools=True,
)
case (
ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1
| ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1
Expand All @@ -123,7 +135,7 @@ async def get_bedrock_adapter(
deployment=model,
bedrock=await Bedrock.acreate(aws_client_config),
storage=create_file_storage(api_key),
input_tokenizer_factory=llama_tokenizer_factory,
input_tokenizer_factory=default_converse_tokenizer_factory,
support_tools=False,
)
case (
Expand All @@ -136,7 +148,7 @@ async def get_bedrock_adapter(
deployment=model,
bedrock=await Bedrock.acreate(aws_client_config),
storage=create_file_storage(api_key),
input_tokenizer_factory=llama_tokenizer_factory,
input_tokenizer_factory=default_converse_tokenizer_factory,
support_tools=True,
)
case (
Expand Down
29 changes: 1 addition & 28 deletions aidial_adapter_bedrock/llm/model/llama/v3.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
import json
from typing import Awaitable, Callable

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.llm.converse.adapter import (
ConverseAdapter,
ConverseMessages,
)
from aidial_adapter_bedrock.llm.converse.types import (
ConverseDeployment,
ConverseRequestWrapper,
)
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
from aidial_adapter_bedrock.llm.converse.adapter import ConverseAdapter


class ConverseAdapterWithStreamingEmulation(ConverseAdapter):
Expand All @@ -23,19 +12,3 @@ def is_stream(self, params: ModelParameters) -> bool:
if self.get_tool_config(params):
return False
return params.stream


def input_tokenizer_factory(
    deployment: ConverseDeployment, params: ConverseRequestWrapper
) -> Callable[[ConverseMessages], Awaitable[int]]:
    """Create an async token estimator for a Converse request.

    Request-level overhead (tool config + system prompt, serialized to JSON
    and tokenized) is precomputed; the returned coroutine adds the token
    count of each message payload to that overhead.
    """
    # Both request-level parts contribute a constant cost per request.
    overhead = sum(
        default_tokenize_string(json.dumps(part))
        for part in (params.toolConfig, params.system)
    )

    async def tokenizer(msg_items: ConverseMessages) -> int:
        per_message = sum(
            default_tokenize_string(json.dumps(item[0])) for item in msg_items
        )
        return per_message + overhead

    return tokenizer
58 changes: 35 additions & 23 deletions tests/integration_tests/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def get_id(self):
ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1: _WEST,
ChatCompletionDeployment.COHERE_COMMAND_TEXT_V14: _WEST,
ChatCompletionDeployment.COHERE_COMMAND_LIGHT_TEXT_V14: _WEST,
ChatCompletionDeployment.AMAZON_NOVA_MICRO: _EAST,
ChatCompletionDeployment.AMAZON_NOVA_PRO: _EAST,
ChatCompletionDeployment.AMAZON_NOVA_LITE: _EAST,
}


Expand All @@ -140,16 +143,26 @@ def supports_tools(deployment: ChatCompletionDeployment) -> bool:
ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1,
ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1,
ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1,
# Technically, Nova Micro supports tools, but it's unstable
# ChatCompletionDeployment.AMAZON_NOVA_MICRO,
ChatCompletionDeployment.AMAZON_NOVA_PRO,
ChatCompletionDeployment.AMAZON_NOVA_LITE,
ChatCompletionDeployment.AMAZON_NOVA_MICRO,
]


def supports_parallel_tool_calls(deployment: ChatCompletionDeployment) -> bool:
    """Return True if the deployment can emit several tool calls in one turn.

    A deployment must support tools at all (``supports_tools``); on top of
    that, the Nova family and the listed Claude 3.5 Sonnet v2 / Llama 3.1
    deployments are excluded because they produce one tool call per response.
    """
    # NOTE(review): the diff rendering had fused the pre-change body with the
    # post-change body, leaving an unreachable second return; only the
    # post-change implementation is kept here.
    return (
        deployment
        not in [
            ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2,
            ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2_US,
            ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1,
            ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1,
        ]
        and not is_nova(deployment)
        and supports_tools(deployment)
    )


def is_llama3(deployment: ChatCompletionDeployment) -> bool:
Expand Down Expand Up @@ -193,6 +206,14 @@ def is_claude3(deployment: ChatCompletionDeployment) -> bool:
]


def is_nova(deployment: ChatCompletionDeployment) -> bool:
    """Check whether the deployment is one of the Amazon Nova models."""
    nova_deployments = {
        ChatCompletionDeployment.AMAZON_NOVA_MICRO,
        ChatCompletionDeployment.AMAZON_NOVA_PRO,
        ChatCompletionDeployment.AMAZON_NOVA_LITE,
    }
    return deployment in nova_deployments


def is_ai21(deployment: ChatCompletionDeployment) -> bool:
return deployment in [
ChatCompletionDeployment.AI21_J2_GRANDE_INSTRUCT,
Expand All @@ -211,6 +232,8 @@ def is_vision_model(deployment: ChatCompletionDeployment) -> bool:
allowed_models = [
ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1,
ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1,
ChatCompletionDeployment.AMAZON_NOVA_PRO,
ChatCompletionDeployment.AMAZON_NOVA_LITE,
]

# Claude 3.5 Haiku was launched as a text-only model
Expand Down Expand Up @@ -276,26 +299,15 @@ def test_case(
)
)

def dial_recall_expected(r: ChatCompletionResult):
    """Expect the model to recall the name from earlier in the dialog.

    For deployments whose recall quality has degraded (Titan, Cohere
    Command), the expectation is inverted: they are expected to fail.
    """
    recalled = "anton" in r.content.lower()
    # `deployment` is captured from the enclosing test scope.
    degraded_models = [
        ChatCompletionDeployment.AMAZON_TITAN_TG1_LARGE,
        ChatCompletionDeployment.COHERE_COMMAND_TEXT_V14,
    ]
    if deployment in degraded_models:
        return not recalled
    return recalled

test_case(
name="dialog recall",
messages=[
user("my name is Anton"),
ai("nice to meet you"),
user("what's my name?"),
user("Remember Paris city. Just say hello"),
ai("Hello"),
user("What city did I mention earlier?"),
],
max_tokens=32,
expected=dial_recall_expected,
expected=lambda s: "paris" in s.content.lower(),
)

test_case(
Expand Down Expand Up @@ -345,7 +357,7 @@ def dial_recall_expected(r: ChatCompletionResult):
expected_empty_message_error = streaming_error(
cohere_invalid_request_error
)
elif is_llama3(deployment):
elif is_llama3(deployment) or is_nova(deployment):
expected_empty_message_error = streaming_error(
ExpectedException(
type=BadRequestError,
Expand Down Expand Up @@ -374,7 +386,7 @@ def dial_recall_expected(r: ChatCompletionResult):
expected_whitespace_message = streaming_error(
cohere_invalid_request_error
)
elif is_llama3(deployment):
elif is_llama3(deployment) or is_nova(deployment):
expected_whitespace_message = streaming_error(
ExpectedException(
type=BadRequestError,
Expand Down
3 changes: 3 additions & 0 deletions tests/unit_tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

test_cases: List[Tuple[ChatCompletionDeployment, bool, bool]] = [
(ChatCompletionDeployment.AMAZON_TITAN_TG1_LARGE, True, True),
(ChatCompletionDeployment.AMAZON_NOVA_PRO, True, True),
(ChatCompletionDeployment.AMAZON_NOVA_LITE, True, True),
(ChatCompletionDeployment.AMAZON_NOVA_MICRO, True, True),
(ChatCompletionDeployment.AI21_J2_GRANDE_INSTRUCT, True, True),
(ChatCompletionDeployment.AI21_J2_JUMBO_INSTRUCT, True, True),
(ChatCompletionDeployment.AI21_J2_MID_V1, True, True),
Expand Down

0 comments on commit c061d9c

Please sign in to comment.