Skip to content

Commit

Permalink
feat: add Converse API, add LLama 3.2 (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-romanov-o authored Nov 19, 2024
1 parent 56a2065 commit 67d3d6d
Show file tree
Hide file tree
Showing 24 changed files with 1,611 additions and 583 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ Note that a model supports `/truncate_prompt` endpoint if and only if it support
|Anthropic|Claude 2.1|anthropic.claude-v2:1|text-to-text||||
|Anthropic|Claude 2|anthropic.claude-v2|text-to-text||||
|Anthropic|Claude Instant 1.2|anthropic.claude-instant-v1|text-to-text|🟡|🟡||
|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.2 90B Instruct|us.meta.llama3-2-90b-instruct-v1:0|text-to-text, image-to-text|🟡|🟡||
|Meta|Llama 3.2 11B Instruct|us.meta.llama3-2-11b-instruct-v1:0|text-to-text, image-to-text|🟡|🟡||
|Meta|Llama 3.2 3B Instruct|us.meta.llama3-2-3b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.2 1B Instruct|us.meta.llama3-2-1b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3.1 8B Instruct|meta.llama3-1-8b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3 Chat 70B Instruct|meta.llama3-70b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 3 Chat 8B Instruct|meta.llama3-8b-instruct-v1:0|text-to-text|🟡|🟡||
|Meta|Llama 2 Chat 70B|meta.llama2-70b-chat-v1|text-to-text|🟡|🟡||
|Meta|Llama 2 Chat 13B|meta.llama2-13b-chat-v1|text-to-text|🟡|🟡||
|Stability AI|SDXL 1.0|stability.stable-diffusion-xl-v1|text-to-image||🟡||
|Stability AI|SD3 Large 1.0|stability.sd3-large-v1:0|text-to-image / image-to-image||🟡||
|Stability AI|Stable Image Ultra 1.0|stability.stable-image-ultra-v1:0|text-to-image||🟡||
Expand Down
20 changes: 19 additions & 1 deletion aidial_adapter_bedrock/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from abc import ABC
from logging import DEBUG
from typing import Any, AsyncIterator, Mapping, Optional, Tuple
from typing import Any, AsyncIterator, Mapping, Optional, Tuple, Unpack

import boto3
from botocore.eventstream import EventStream
Expand All @@ -10,6 +10,7 @@

from aidial_adapter_bedrock.aws_client_config import AWSClientConfig
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.converse.types import ConverseRequest
from aidial_adapter_bedrock.utils.concurrency import (
make_async,
to_async_iterator,
Expand All @@ -36,6 +37,23 @@ async def acreate(cls, aws_client_config: AWSClientConfig) -> "Bedrock":
)
return cls(client)

async def aconverse_non_streaming(
self, model: str, **params: Unpack[ConverseRequest]
):
response = await make_async(
lambda: self.client.converse(modelId=model, **params)
)
return response

async def aconverse_streaming(
self, model: str, **params: Unpack[ConverseRequest]
):
response = await make_async(
lambda: self.client.converse_stream(modelId=model, **params)
)

return to_async_iterator(iter(response["stream"]))

def _create_invoke_params(self, model: str, body: dict) -> dict:
return {
"modelId": model,
Expand Down
10 changes: 6 additions & 4 deletions aidial_adapter_bedrock/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ class ChatCompletionDeployment(str, Enum):
STABILITY_STABLE_DIFFUSION_3_LARGE_V1 = "stability.sd3-large-v1:0"
STABILITY_STABLE_IMAGE_ULTRA_V1 = "stability.stable-image-ultra-v1:0"

META_LLAMA2_13B_CHAT_V1 = "meta.llama2-13b-chat-v1"
META_LLAMA2_70B_CHAT_V1 = "meta.llama2-70b-chat-v1"
META_LLAMA3_8B_INSTRUCT_V1 = "meta.llama3-8b-instruct-v1:0"
META_LLAMA3_70B_INSTRUCT_V1 = "meta.llama3-70b-instruct-v1:0"
META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0"
META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0"
META_LLAMA3_1_8B_INSTRUCT_V1 = "meta.llama3-1-8b-instruct-v1:0"
META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0"
META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0"
META_LLAMA3_2_1B_INSTRUCT_V1 = "us.meta.llama3-2-1b-instruct-v1:0"
META_LLAMA3_2_3B_INSTRUCT_V1 = "us.meta.llama3-2-3b-instruct-v1:0"
META_LLAMA3_2_11B_INSTRUCT_V1 = "us.meta.llama3-2-11b-instruct-v1:0"
META_LLAMA3_2_90B_INSTRUCT_V1 = "us.meta.llama3-2-90b-instruct-v1:0"

COHERE_COMMAND_TEXT_V14 = "cohere.command-text-v14"
COHERE_COMMAND_LIGHT_TEXT_V14 = "cohere.command-light-text-v14"
Expand Down
Empty file.
180 changes: 180 additions & 0 deletions aidial_adapter_bedrock/llm/converse/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from typing import Any, Awaitable, Callable, List, Tuple

from aidial_sdk.chat_completion import Message as DialMessage

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.storage import FileStorage
from aidial_adapter_bedrock.llm.chat_model import (
ChatCompletionAdapter,
keep_last,
turn_based_partitioner,
)
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.converse.input import (
extract_converse_system_prompt,
to_converse_messages,
to_converse_tools,
)
from aidial_adapter_bedrock.llm.converse.output import (
process_non_streaming,
process_streaming,
)
from aidial_adapter_bedrock.llm.converse.types import (
ConverseDeployment,
ConverseMessage,
ConverseRequestWrapper,
ConverseTools,
InferenceConfig,
)
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
from aidial_adapter_bedrock.llm.truncate_prompt import (
DiscardedMessages,
truncate_prompt,
)
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.list import omit_by_indices
from aidial_adapter_bedrock.utils.list_projection import ListProjection

ConverseMessages = List[Tuple[ConverseMessage, Any]]


class ConverseAdapter(ChatCompletionAdapter):
deployment: str
bedrock: Bedrock
storage: FileStorage | None

tokenize_text: Callable[[str], int] = default_tokenize_string
input_tokenizer_factory: Callable[
[ConverseDeployment, ConverseRequestWrapper],
Callable[[ConverseMessages], Awaitable[int]],
]
support_tools: bool
partitioner: Callable[[ConverseMessages], List[int]] = (
turn_based_partitioner
)

async def _discard_messages(
self, params: ConverseRequestWrapper, max_prompt_tokens: int | None
) -> Tuple[DiscardedMessages | None, ConverseRequestWrapper]:
if max_prompt_tokens is None:
return None, params

discarded_messages, messages = await truncate_prompt(
messages=params.messages.list,
tokenizer=self.input_tokenizer_factory(self.deployment, params),
keep_message=keep_last,
partitioner=self.partitioner,
model_limit=None,
user_limit=max_prompt_tokens,
)

return list(
params.messages.to_original_indices(discarded_messages)
), ConverseRequestWrapper(
messages=ListProjection(
omit_by_indices(messages, discarded_messages)
),
system=params.system,
inferenceConfig=params.inferenceConfig,
toolConfig=params.toolConfig,
)

async def count_prompt_tokens(
self, params: ModelParameters, messages: List[DialMessage]
) -> int:
converse_params = await self.construct_converse_params(messages, params)
return await self.input_tokenizer_factory(
self.deployment, converse_params
)(converse_params.messages.list)

async def count_completion_tokens(self, string: str) -> int:
return self.tokenize_text(string)

async def compute_discarded_messages(
self, params: ModelParameters, messages: List[DialMessage]
) -> DiscardedMessages | None:
converse_params = await self.construct_converse_params(messages, params)
discarded_messages, _ = await self._discard_messages(
converse_params, params.max_prompt_tokens
)
return discarded_messages

def get_tool_config(self, params: ModelParameters) -> ConverseTools | None:
if params.tool_config and not self.support_tools:
raise ValidationError("Tools are not supported")
return (
to_converse_tools(params.tool_config)
if params.tool_config
else None
)

async def construct_converse_params(
self,
messages: List[DialMessage],
params: ModelParameters,
) -> ConverseRequestWrapper:
system_prompt_extraction = extract_converse_system_prompt(messages)
converse_messages = await to_converse_messages(
system_prompt_extraction.non_system_messages,
self.storage,
start_offset=system_prompt_extraction.system_message_count,
)
system_message = system_prompt_extraction.system_prompt
if not converse_messages.list:
raise ValidationError("List of messages must not be empty")

return ConverseRequestWrapper(
system=[system_message] if system_message else None,
messages=converse_messages,
inferenceConfig=InferenceConfig(
**remove_nones(
{
"temperature": params.temperature,
"topP": params.top_p,
"maxTokens": params.max_tokens,
"stopSequences": params.stop,
}
)
),
toolConfig=self.get_tool_config(params),
)

def is_stream(self, params: ModelParameters) -> bool:
return params.stream

async def chat(
self,
consumer: Consumer,
params: ModelParameters,
messages: List[DialMessage],
) -> None:

converse_params = await self.construct_converse_params(messages, params)
discarded_messages, converse_params = await self._discard_messages(
converse_params, params.max_prompt_tokens
)
if not converse_params.messages.raw_list:
raise ValidationError("No messages left after truncation")

consumer.set_discarded_messages(discarded_messages)

if self.is_stream(params):
await process_streaming(
params=params,
stream=(
await self.bedrock.aconverse_streaming(
self.deployment, **converse_params.to_request()
)
),
consumer=consumer,
)
else:
process_non_streaming(
params=params,
response=await self.bedrock.aconverse_non_streaming(
self.deployment, **converse_params.to_request()
),
consumer=consumer,
)
12 changes: 12 additions & 0 deletions aidial_adapter_bedrock/llm/converse/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from aidial_sdk.chat_completion import FinishReason as DialFinishReason

from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason

CONVERSE_TO_DIAL_FINISH_REASON = {
ConverseStopReason.END_TURN: DialFinishReason.STOP,
ConverseStopReason.TOOL_USE: DialFinishReason.TOOL_CALLS,
ConverseStopReason.MAX_TOKENS: DialFinishReason.LENGTH,
ConverseStopReason.STOP_SEQUENCE: DialFinishReason.STOP,
ConverseStopReason.GUARDRAIL_INTERVENED: DialFinishReason.CONTENT_FILTER,
ConverseStopReason.CONTENT_FILTERED: DialFinishReason.CONTENT_FILTER,
}
Loading

0 comments on commit 67d3d6d

Please sign in to comment.