feat: add Converse API, add Llama 3.2 #177

Merged 16 commits on Nov 19, 2024 (changes shown from 12 commits)
10 changes: 6 additions & 4 deletions README.md
@@ -20,13 +20,15 @@ Note that a model supports `/truncate_prompt` endpoint if and only if it support
|Anthropic|Claude 2.1|anthropic.claude-v2:1|text-to-text|✅|✅|✅|
|Anthropic|Claude 2|anthropic.claude-v2|text-to-text|✅|✅|❌|
|Anthropic|Claude Instant 1.2|anthropic.claude-instant-v1|text-to-text|🟡|🟡|❌|
-|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡|❌|
-|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡|❌|
+|Meta|Llama 3.2 90B Instruct|us.meta.llama3-2-90b-instruct-v1:0|text-to-text|🟡|🟡|✅|
+|Meta|Llama 3.2 11B Instruct|us.meta.llama3-2-11b-instruct-v1:0|text-to-text|🟡|🟡|❌|
+|Meta|Llama 3.2 3B Instruct|us.meta.llama3-2-3b-instruct-v1:0|text-to-text|🟡|🟡|❌|
+|Meta|Llama 3.2 1B Instruct|us.meta.llama3-2-1b-instruct-v1:0|text-to-text|🟡|🟡|❌|
+|Meta|Llama 3.1 405B Instruct|meta.llama3-1-405b-instruct-v1:0|text-to-text|🟡|🟡|✅|
+|Meta|Llama 3.1 70B Instruct|meta.llama3-1-70b-instruct-v1:0|text-to-text|🟡|🟡|✅|
|Meta|Llama 3.1 8B Instruct|meta.llama3-1-8b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 3 Chat 70B Instruct|meta.llama3-70b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 3 Chat 8B Instruct|meta.llama3-8b-instruct-v1:0|text-to-text|🟡|🟡|❌|
|Meta|Llama 2 Chat 70B|meta.llama2-70b-chat-v1|text-to-text|🟡|🟡|❌|
|Meta|Llama 2 Chat 13B|meta.llama2-13b-chat-v1|text-to-text|🟡|🟡|❌|
|Stability AI|SDXL 1.0|stability.stable-diffusion-xl-v1|text-to-image|❌|🟡|❌|
|Stability AI|SD3 Large 1.0|stability.sd3-large-v1:0|text-to-image / image-to-image|❌|🟡|❌|
|Stability AI|Stable Image Ultra 1.0|stability.stable-image-ultra-v1:0|text-to-image|❌|🟡|❌|
20 changes: 19 additions & 1 deletion aidial_adapter_bedrock/bedrock.py
@@ -1,7 +1,7 @@
import json
from abc import ABC
from logging import DEBUG
-from typing import Any, AsyncIterator, Mapping, Optional, Tuple
+from typing import Any, AsyncIterator, Mapping, Optional, Tuple, Unpack

import boto3
from botocore.eventstream import EventStream
@@ -10,6 +10,7 @@

from aidial_adapter_bedrock.aws_client_config import AWSClientConfig
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
+from aidial_adapter_bedrock.llm.converse.types import ConverseRequest
from aidial_adapter_bedrock.utils.concurrency import (
make_async,
to_async_iterator,
@@ -36,6 +37,23 @@ async def acreate(cls, aws_client_config: AWSClientConfig) -> "Bedrock":
)
return cls(client)

+    async def aconverse_non_streaming(
+        self, model: str, **params: Unpack[ConverseRequest]
+    ):
+        response = await make_async(
+            lambda: self.client.converse(modelId=model, **params)
+        )
+        return response
+
+    async def aconverse_streaming(
+        self, model: str, **params: Unpack[ConverseRequest]
+    ):
+        response = await make_async(
+            lambda: self.client.converse_stream(modelId=model, **params)
+        )
+
+        return to_async_iterator(iter(response["stream"]))

def _create_invoke_params(self, model: str, body: dict) -> dict:
return {
"modelId": model,
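For context, here is a minimal sketch of how the two new `Bedrock` methods might be exercised. The model id and the `AWSClientConfig` construction are illustrative assumptions (check the actual config model for its fields); the request payload follows the AWS Bedrock Converse API shape:

```python
import asyncio

from aidial_adapter_bedrock.aws_client_config import AWSClientConfig
from aidial_adapter_bedrock.bedrock import Bedrock


async def main() -> None:
    # Assumption: AWSClientConfig accepts a region; the real model may
    # require additional fields for credentials.
    client = await Bedrock.acreate(AWSClientConfig(region="us-east-1"))

    # Payload in the AWS Bedrock Converse API shape.
    request = {
        "messages": [{"role": "user", "content": [{"text": "Hello!"}]}],
        "inferenceConfig": {"maxTokens": 128},
    }

    # Non-streaming: the full response dict is returned at once.
    response = await client.aconverse_non_streaming(
        "us.meta.llama3-2-90b-instruct-v1:0", **request
    )
    print(response["output"]["message"])

    # Streaming: stream events arrive via an async iterator.
    async for event in await client.aconverse_streaming(
        "us.meta.llama3-2-90b-instruct-v1:0", **request
    ):
        print(event)


asyncio.run(main())
```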
10 changes: 6 additions & 4 deletions aidial_adapter_bedrock/deployments.py
@@ -43,13 +43,15 @@ class ChatCompletionDeployment(str, Enum):
STABILITY_STABLE_DIFFUSION_3_LARGE_V1 = "stability.sd3-large-v1:0"
STABILITY_STABLE_IMAGE_ULTRA_V1 = "stability.stable-image-ultra-v1:0"

META_LLAMA2_13B_CHAT_V1 = "meta.llama2-13b-chat-v1"
META_LLAMA2_70B_CHAT_V1 = "meta.llama2-70b-chat-v1"
META_LLAMA3_8B_INSTRUCT_V1 = "meta.llama3-8b-instruct-v1:0"
META_LLAMA3_70B_INSTRUCT_V1 = "meta.llama3-70b-instruct-v1:0"
-    META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0"
-    META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0"
    META_LLAMA3_1_8B_INSTRUCT_V1 = "meta.llama3-1-8b-instruct-v1:0"
+    META_LLAMA3_1_70B_INSTRUCT_V1 = "meta.llama3-1-70b-instruct-v1:0"
+    META_LLAMA3_1_405B_INSTRUCT_V1 = "meta.llama3-1-405b-instruct-v1:0"
+    META_LLAMA3_2_1B_INSTRUCT_V1 = "us.meta.llama3-2-1b-instruct-v1:0"
+    META_LLAMA3_2_3B_INSTRUCT_V1 = "us.meta.llama3-2-3b-instruct-v1:0"
+    META_LLAMA3_2_11B_INSTRUCT_V1 = "us.meta.llama3-2-11b-instruct-v1:0"
+    META_LLAMA3_2_90B_INSTRUCT_V1 = "us.meta.llama3-2-90b-instruct-v1:0"

COHERE_COMMAND_TEXT_V14 = "cohere.command-text-v14"
COHERE_COMMAND_LIGHT_TEXT_V14 = "cohere.command-light-text-v14"
Empty file.
180 changes: 180 additions & 0 deletions aidial_adapter_bedrock/llm/converse/adapter.py
@@ -0,0 +1,180 @@
from typing import Any, Awaitable, Callable, List, Tuple

from aidial_sdk.chat_completion import Message as DialMessage

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.storage import FileStorage
from aidial_adapter_bedrock.llm.chat_model import (
ChatCompletionAdapter,
keep_last,
turn_based_partitioner,
)
from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.converse.input import (
extract_converse_system_prompt,
to_converse_messages,
to_converse_tools,
)
from aidial_adapter_bedrock.llm.converse.output import (
process_non_streaming,
process_streaming,
)
from aidial_adapter_bedrock.llm.converse.types import (
ConverseDeployment,
ConverseMessage,
ConverseRequestWrapper,
ConverseTools,
InferenceConfig,
)
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string
from aidial_adapter_bedrock.llm.truncate_prompt import (
DiscardedMessages,
truncate_prompt,
)
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.list import omit_by_indices
from aidial_adapter_bedrock.utils.list_projection import ListProjection

ConverseMessages = List[Tuple[ConverseMessage, Any]]


class ConverseAdapter(ChatCompletionAdapter):
deployment: str
bedrock: Bedrock
storage: FileStorage | None

tokenize_text: Callable[[str], int] = default_tokenize_string
input_tokenizer_factory: Callable[
[ConverseDeployment, ConverseRequestWrapper],
Callable[[ConverseMessages], Awaitable[int]],
]
support_tools: bool
partitioner: Callable[[ConverseMessages], List[int]] = (
turn_based_partitioner
)

async def _discard_messages(
self, params: ConverseRequestWrapper, max_prompt_tokens: int | None
) -> Tuple[DiscardedMessages | None, ConverseRequestWrapper]:
if max_prompt_tokens is None:
return None, params

discarded_messages, messages = await truncate_prompt(
messages=params.messages.list,
tokenizer=self.input_tokenizer_factory(self.deployment, params),
keep_message=keep_last,
partitioner=self.partitioner,
model_limit=None,
user_limit=max_prompt_tokens,
)

return list(
params.messages.to_original_indices(discarded_messages)
), ConverseRequestWrapper(
messages=ListProjection(
omit_by_indices(messages, discarded_messages)
),
system=params.system,
inferenceConfig=params.inferenceConfig,
toolConfig=params.toolConfig,
)

async def count_prompt_tokens(
self, params: ModelParameters, messages: List[DialMessage]
) -> int:
converse_params = await self.construct_converse_params(messages, params)
return await self.input_tokenizer_factory(
self.deployment, converse_params
)(converse_params.messages.list)

async def count_completion_tokens(self, string: str) -> int:
return self.tokenize_text(string)

async def compute_discarded_messages(
self, params: ModelParameters, messages: List[DialMessage]
) -> DiscardedMessages | None:
converse_params = await self.construct_converse_params(messages, params)
discarded_messages, _ = await self._discard_messages(
converse_params, params.max_prompt_tokens
)
return discarded_messages

def get_tool_config(self, params: ModelParameters) -> ConverseTools | None:
if params.tool_config and not self.support_tools:
raise ValidationError("Tools are not supported")
return (
to_converse_tools(params.tool_config)
if params.tool_config
else None
)

async def construct_converse_params(
self,
messages: List[DialMessage],
params: ModelParameters,
) -> ConverseRequestWrapper:
system_prompt_extraction = extract_converse_system_prompt(messages)
processed_messages = await to_converse_messages(
system_prompt_extraction.non_system_messages,
self.storage,
start_offset=system_prompt_extraction.system_message_count,
)
system_message = system_prompt_extraction.system_prompt
if not processed_messages.list:
raise ValidationError("List of messages must not be empty")

return ConverseRequestWrapper(
system=[system_message] if system_message else None,
messages=processed_messages,
inferenceConfig=InferenceConfig(
**remove_nones(
{
"temperature": params.temperature,
"topP": params.top_p,
"maxTokens": params.max_tokens,
"stopSequences": params.stop,
}
)
),
toolConfig=self.get_tool_config(params),
)

def is_stream(self, params: ModelParameters) -> bool:
return params.stream

async def chat(
self,
consumer: Consumer,
params: ModelParameters,
messages: List[DialMessage],
) -> None:

converse_params = await self.construct_converse_params(messages, params)
discarded_messages, converse_params = await self._discard_messages(
converse_params, params.max_prompt_tokens
)
if not converse_params.messages.raw_list:
raise ValidationError("No messages left after truncation")

consumer.set_discarded_messages(discarded_messages)

if self.is_stream(params):
await process_streaming(
params=params,
stream=(
await self.bedrock.aconverse_streaming(
self.deployment, **converse_params.to_request()
)
),
consumer=consumer,
)
else:
process_non_streaming(
params=params,
response=await self.bedrock.aconverse_non_streaming(
self.deployment, **converse_params.to_request()
),
consumer=consumer,
)
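Since the adapter declares its dependencies as fields, wiring it up might look roughly like the sketch below. This assumes `ChatCompletionAdapter` is a pydantic-style model constructed via keywords (as the field declarations suggest); the tokenizer here is a hypothetical stand-in, not the project's real factory:

```python
from typing import Any, Awaitable, Callable, List, Tuple

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.llm.converse.adapter import ConverseAdapter
from aidial_adapter_bedrock.llm.converse.types import (
    ConverseDeployment,
    ConverseMessage,
    ConverseRequestWrapper,
)


def make_tokenizer(
    deployment: ConverseDeployment, params: ConverseRequestWrapper
) -> Callable[[List[Tuple[ConverseMessage, Any]]], Awaitable[int]]:
    # Hypothetical tokenizer: roughly 4 characters per token over the
    # serialized messages. The real factory lives elsewhere in the repo.
    async def tokenize(messages: List[Tuple[ConverseMessage, Any]]) -> int:
        return sum(len(str(message)) // 4 for message, _ in messages)

    return tokenize


async def make_adapter(bedrock: Bedrock) -> ConverseAdapter:
    return ConverseAdapter(
        deployment="us.meta.llama3-2-90b-instruct-v1:0",  # illustrative
        bedrock=bedrock,
        storage=None,
        input_tokenizer_factory=make_tokenizer,
        support_tools=True,
    )
```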
12 changes: 12 additions & 0 deletions aidial_adapter_bedrock/llm/converse/constants.py
@@ -0,0 +1,12 @@
from aidial_sdk.chat_completion import FinishReason as DialFinishReason

from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason

CONVERSE_TO_DIAL_FINISH_REASON = {
ConverseStopReason.END_TURN: DialFinishReason.STOP,
ConverseStopReason.TOOL_USE: DialFinishReason.TOOL_CALLS,
ConverseStopReason.MAX_TOKENS: DialFinishReason.LENGTH,
ConverseStopReason.STOP_SEQUENCE: DialFinishReason.STOP,
ConverseStopReason.GUARDRAIL_INTERVENED: DialFinishReason.CONTENT_FILTER,
ConverseStopReason.CONTENT_FILTERED: DialFinishReason.CONTENT_FILTER,
}
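With this table, translating a Converse stop reason into a DIAL finish reason is a plain dictionary lookup. A minimal sketch follows; the STOP fallback for unmapped reasons is an assumption, not necessarily what `output.py` does:

```python
from aidial_sdk.chat_completion import FinishReason as DialFinishReason

from aidial_adapter_bedrock.llm.converse.constants import (
    CONVERSE_TO_DIAL_FINISH_REASON,
)
from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason


def to_dial_finish_reason(reason: ConverseStopReason) -> DialFinishReason:
    # Fall back to STOP for reasons missing from the mapping (assumption).
    return CONVERSE_TO_DIAL_FINISH_REASON.get(reason, DialFinishReason.STOP)
```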