-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add Converse API, add LLama 3.2 (#177)
- Loading branch information
1 parent
56a2065
commit 67d3d6d
Showing
24 changed files
with
1,611 additions
and
583 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
from typing import Any, Awaitable, Callable, List, Tuple | ||
|
||
from aidial_sdk.chat_completion import Message as DialMessage | ||
|
||
from aidial_adapter_bedrock.bedrock import Bedrock | ||
from aidial_adapter_bedrock.dial_api.request import ModelParameters | ||
from aidial_adapter_bedrock.dial_api.storage import FileStorage | ||
from aidial_adapter_bedrock.llm.chat_model import ( | ||
ChatCompletionAdapter, | ||
keep_last, | ||
turn_based_partitioner, | ||
) | ||
from aidial_adapter_bedrock.llm.consumer import Consumer | ||
from aidial_adapter_bedrock.llm.converse.input import ( | ||
extract_converse_system_prompt, | ||
to_converse_messages, | ||
to_converse_tools, | ||
) | ||
from aidial_adapter_bedrock.llm.converse.output import ( | ||
process_non_streaming, | ||
process_streaming, | ||
) | ||
from aidial_adapter_bedrock.llm.converse.types import ( | ||
ConverseDeployment, | ||
ConverseMessage, | ||
ConverseRequestWrapper, | ||
ConverseTools, | ||
InferenceConfig, | ||
) | ||
from aidial_adapter_bedrock.llm.errors import ValidationError | ||
from aidial_adapter_bedrock.llm.tokenize import default_tokenize_string | ||
from aidial_adapter_bedrock.llm.truncate_prompt import ( | ||
DiscardedMessages, | ||
truncate_prompt, | ||
) | ||
from aidial_adapter_bedrock.utils.json import remove_nones | ||
from aidial_adapter_bedrock.utils.list import omit_by_indices | ||
from aidial_adapter_bedrock.utils.list_projection import ListProjection | ||
|
||
# A prompt message paired with per-message metadata consumed by the
# tokenizer/truncation machinery. NOTE(review): the meaning of the second
# tuple element is not visible here — confirm against ListProjection/
# truncate_prompt before relying on it.
ConverseMessages = List[Tuple[ConverseMessage, Any]]
|
||
|
||
class ConverseAdapter(ChatCompletionAdapter):
    """Chat-completion adapter backed by the Bedrock Converse API.

    Translates DIAL chat-completion requests (messages, sampling params,
    tool config) into Converse requests, handles optional prompt
    truncation to ``max_prompt_tokens``, and dispatches to the streaming
    or non-streaming Bedrock endpoint.
    """

    # Bedrock model deployment identifier passed to the Converse endpoints.
    deployment: str
    # Client wrapping the Bedrock runtime API.
    bedrock: Bedrock
    # Optional file storage — presumably used when converting DIAL message
    # attachments in to_converse_messages; confirm against its implementation.
    storage: FileStorage | None

    # Tokenizer for completion strings (see count_completion_tokens).
    tokenize_text: Callable[[str], int] = default_tokenize_string
    # Factory producing a prompt tokenizer; parameterized by both the
    # deployment and the full request, so system prompt / tool config can
    # influence the token count.
    input_tokenizer_factory: Callable[
        [ConverseDeployment, ConverseRequestWrapper],
        Callable[[ConverseMessages], Awaitable[int]],
    ]
    # Whether the underlying model accepts tool definitions.
    support_tools: bool
    # Splits the message list into indivisible chunks for truncation;
    # defaults to turn-based partitioning.
    partitioner: Callable[[ConverseMessages], List[int]] = (
        turn_based_partitioner
    )

    async def _discard_messages(
        self, params: ConverseRequestWrapper, max_prompt_tokens: int | None
    ) -> Tuple[DiscardedMessages | None, ConverseRequestWrapper]:
        """Truncate the prompt to fit within ``max_prompt_tokens``.

        Returns the indices of discarded messages (mapped back to the
        original DIAL message numbering) and a new request wrapper with
        those messages removed. No-op when ``max_prompt_tokens`` is None.
        """
        if max_prompt_tokens is None:
            return None, params

        discarded_messages, messages = await truncate_prompt(
            messages=params.messages.list,
            tokenizer=self.input_tokenizer_factory(self.deployment, params),
            keep_message=keep_last,
            partitioner=self.partitioner,
            model_limit=None,
            user_limit=max_prompt_tokens,
        )

        # Discarded indices are local to the projected message list; map
        # them back to the caller's original indices before reporting.
        return list(
            params.messages.to_original_indices(discarded_messages)
        ), ConverseRequestWrapper(
            messages=ListProjection(
                omit_by_indices(messages, discarded_messages)
            ),
            system=params.system,
            inferenceConfig=params.inferenceConfig,
            toolConfig=params.toolConfig,
        )

    async def count_prompt_tokens(
        self, params: ModelParameters, messages: List[DialMessage]
    ) -> int:
        """Count prompt tokens for the given messages under ``params``."""
        converse_params = await self.construct_converse_params(messages, params)
        return await self.input_tokenizer_factory(
            self.deployment, converse_params
        )(converse_params.messages.list)

    async def count_completion_tokens(self, string: str) -> int:
        """Count tokens in a completion string via ``tokenize_text``."""
        return self.tokenize_text(string)

    async def compute_discarded_messages(
        self, params: ModelParameters, messages: List[DialMessage]
    ) -> DiscardedMessages | None:
        """Return indices of messages that truncation would discard, or None."""
        converse_params = await self.construct_converse_params(messages, params)
        discarded_messages, _ = await self._discard_messages(
            converse_params, params.max_prompt_tokens
        )
        return discarded_messages

    def get_tool_config(self, params: ModelParameters) -> ConverseTools | None:
        """Convert the DIAL tool config to Converse format.

        Raises:
            ValidationError: if tools are requested but unsupported.
        """
        if params.tool_config and not self.support_tools:
            raise ValidationError("Tools are not supported")
        return (
            to_converse_tools(params.tool_config)
            if params.tool_config
            else None
        )

    async def construct_converse_params(
        self,
        messages: List[DialMessage],
        params: ModelParameters,
    ) -> ConverseRequestWrapper:
        """Build a Converse request from DIAL messages and model parameters.

        Extracts the system prompt separately (Converse takes it outside
        the message list), converts the remaining messages, and maps the
        sampling parameters into an ``InferenceConfig``, dropping unset ones.

        Raises:
            ValidationError: if no non-system messages remain.
        """
        system_prompt_extraction = extract_converse_system_prompt(messages)
        converse_messages = await to_converse_messages(
            system_prompt_extraction.non_system_messages,
            self.storage,
            # Offset keeps converted-message indices aligned with the
            # original DIAL list, which included the system messages.
            start_offset=system_prompt_extraction.system_message_count,
        )
        system_message = system_prompt_extraction.system_prompt
        if not converse_messages.list:
            raise ValidationError("List of messages must not be empty")

        return ConverseRequestWrapper(
            system=[system_message] if system_message else None,
            messages=converse_messages,
            inferenceConfig=InferenceConfig(
                # remove_nones drops unset parameters so Bedrock applies
                # its own defaults instead of receiving explicit nulls.
                **remove_nones(
                    {
                        "temperature": params.temperature,
                        "topP": params.top_p,
                        "maxTokens": params.max_tokens,
                        "stopSequences": params.stop,
                    }
                )
            ),
            toolConfig=self.get_tool_config(params),
        )

    def is_stream(self, params: ModelParameters) -> bool:
        """Whether the response should be streamed."""
        return params.stream

    async def chat(
        self,
        consumer: Consumer,
        params: ModelParameters,
        messages: List[DialMessage],
    ) -> None:
        """Run one chat completion, emitting results through ``consumer``.

        Builds the Converse request, applies truncation, records discarded
        messages on the consumer, then calls the streaming or non-streaming
        Bedrock endpoint.

        Raises:
            ValidationError: if truncation removes every message.
        """
        converse_params = await self.construct_converse_params(messages, params)
        discarded_messages, converse_params = await self._discard_messages(
            converse_params, params.max_prompt_tokens
        )
        if not converse_params.messages.raw_list:
            raise ValidationError("No messages left after truncation")

        consumer.set_discarded_messages(discarded_messages)

        if self.is_stream(params):
            await process_streaming(
                params=params,
                stream=(
                    await self.bedrock.aconverse_streaming(
                        self.deployment, **converse_params.to_request()
                    )
                ),
                consumer=consumer,
            )
        else:
            # NOTE(review): unlike process_streaming, this call is not
            # awaited — presumably process_non_streaming is synchronous;
            # verify its definition.
            process_non_streaming(
                params=params,
                response=await self.bedrock.aconverse_non_streaming(
                    self.deployment, **converse_params.to_request()
                ),
                consumer=consumer,
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from aidial_sdk.chat_completion import FinishReason as DialFinishReason | ||
|
||
from aidial_adapter_bedrock.llm.converse.types import ConverseStopReason | ||
|
||
# Maps Bedrock Converse stop reasons onto DIAL finish reasons.
# The mapping is many-to-one: both natural end-of-turn and stop-sequence
# hits surface as STOP, and both guardrail intervention and content
# filtering surface as CONTENT_FILTER.
CONVERSE_TO_DIAL_FINISH_REASON = {
    ConverseStopReason.END_TURN: DialFinishReason.STOP,
    ConverseStopReason.TOOL_USE: DialFinishReason.TOOL_CALLS,
    ConverseStopReason.MAX_TOKENS: DialFinishReason.LENGTH,
    ConverseStopReason.STOP_SEQUENCE: DialFinishReason.STOP,
    ConverseStopReason.GUARDRAIL_INTERVENED: DialFinishReason.CONTENT_FILTER,
    ConverseStopReason.CONTENT_FILTERED: DialFinishReason.CONTENT_FILTER,
}
Oops, something went wrong.