From ff52263ee7559b6fd7185adfb32db7ce6ef34cdf Mon Sep 17 00:00:00 2001
From: Anton Dubovik
Date: Tue, 5 Mar 2024 14:24:12 +0000
Subject: [PATCH] fix: migrated to sdk==0.7.0; turned discarded_messages into a list (#58)

---
 aidial_adapter_vertexai/chat/bison/base.py | 18 ++++--------------
 .../chat/bison/truncate_prompt.py          | 21 +++++++++++++++++++
 .../chat/chat_completion_adapter.py        |  2 +-
 .../chat/gemini/adapter.py                 |  2 +-
 .../chat/imagen/adapter.py                 |  4 ++--
 aidial_adapter_vertexai/chat_completion.py |  7 ++++---
 poetry.lock                                |  8 ++++----
 pyproject.toml                             |  2 +-
 8 files changed, 38 insertions(+), 26 deletions(-)

diff --git a/aidial_adapter_vertexai/chat/bison/base.py b/aidial_adapter_vertexai/chat/bison/base.py
index bba7994..4cad67c 100644
--- a/aidial_adapter_vertexai/chat/bison/base.py
+++ b/aidial_adapter_vertexai/chat/bison/base.py
@@ -11,7 +11,7 @@
 
 from aidial_adapter_vertexai.chat.bison.prompt import BisonPrompt
 from aidial_adapter_vertexai.chat.bison.truncate_prompt import (
-    get_discarded_messages_count,
+    get_discarded_messages,
 )
 from aidial_adapter_vertexai.chat.chat_completion_adapter import (
     ChatCompletionAdapter,
@@ -42,21 +42,11 @@ async def parse_prompt(self, messages: List[Message]) -> BisonPrompt:
     @override
     async def truncate_prompt(
         self, prompt: BisonPrompt, max_prompt_tokens: int
-    ) -> Tuple[BisonPrompt, int]:
+    ) -> Tuple[BisonPrompt, List[int]]:
         if max_prompt_tokens is None:
-            return prompt, 0
+            return prompt, []
 
-        discarded = await get_discarded_messages_count(
-            self, prompt, max_prompt_tokens
-        )
-
-        return (
-            BisonPrompt(
-                context=prompt.context,
-                messages=prompt.messages[discarded:],
-            ),
-            discarded,
-        )
+        return await get_discarded_messages(self, prompt, max_prompt_tokens)
 
     @override
     async def chat(
diff --git a/aidial_adapter_vertexai/chat/bison/truncate_prompt.py b/aidial_adapter_vertexai/chat/bison/truncate_prompt.py
index e79c937..42eb3ce 100644
--- a/aidial_adapter_vertexai/chat/bison/truncate_prompt.py
+++ b/aidial_adapter_vertexai/chat/bison/truncate_prompt.py
@@ -1,3 +1,5 @@
+from typing import List, Tuple
+
 from aidial_adapter_vertexai.chat.bison.prompt import BisonPrompt
 from aidial_adapter_vertexai.chat.chat_completion_adapter import (
     ChatCompletionAdapter,
@@ -77,3 +79,22 @@ async def get_discarded_messages_count(
         discarded_messages_count -= 2
 
     return discarded_messages_count
+
+
+async def get_discarded_messages(
+    model: ChatCompletionAdapter[BisonPrompt],
+    prompt: BisonPrompt,
+    max_prompt_tokens: int,
+) -> Tuple[BisonPrompt, List[int]]:
+    count = await get_discarded_messages_count(model, prompt, max_prompt_tokens)
+
+    truncated_prompt = BisonPrompt(
+        context=prompt.context,
+        messages=prompt.messages[count:],
+    )
+
+    discarded_indices = list(range(count))
+    if prompt.context is not None:
+        discarded_indices = list(map(lambda x: x + 1, discarded_indices))
+
+    return truncated_prompt, discarded_indices
diff --git a/aidial_adapter_vertexai/chat/chat_completion_adapter.py b/aidial_adapter_vertexai/chat/chat_completion_adapter.py
index d3b7a03..a9a2727 100644
--- a/aidial_adapter_vertexai/chat/chat_completion_adapter.py
+++ b/aidial_adapter_vertexai/chat/chat_completion_adapter.py
@@ -18,7 +18,7 @@ async def parse_prompt(self, messages: List[Message]) -> P | UserError:
     @abstractmethod
     async def truncate_prompt(
         self, prompt: P, max_prompt_tokens: int
-    ) -> Tuple[P, int]:
+    ) -> Tuple[P, List[int]]:
         pass
 
     @abstractmethod
diff --git a/aidial_adapter_vertexai/chat/gemini/adapter.py b/aidial_adapter_vertexai/chat/gemini/adapter.py
index 8dc65d1..46670b0 100644
--- a/aidial_adapter_vertexai/chat/gemini/adapter.py
+++ b/aidial_adapter_vertexai/chat/gemini/adapter.py
@@ -91,7 +91,7 @@ async def parse_prompt(
     @override
     async def truncate_prompt(
         self, prompt: GeminiPrompt, max_prompt_tokens: int
-    ) -> Tuple[GeminiPrompt, int]:
+    ) -> Tuple[GeminiPrompt, List[int]]:
         raise NotImplementedError(
             "Prompt truncation is not supported for Genimi model yet"
         )
diff --git a/aidial_adapter_vertexai/chat/imagen/adapter.py b/aidial_adapter_vertexai/chat/imagen/adapter.py
index ef7da62..c290b19 100644
--- a/aidial_adapter_vertexai/chat/imagen/adapter.py
+++ b/aidial_adapter_vertexai/chat/imagen/adapter.py
@@ -50,8 +50,8 @@ async def parse_prompt(self, messages: List[Message]) -> ImagenPrompt:
     @override
     async def truncate_prompt(
         self, prompt: ImagenPrompt, max_prompt_tokens: int
-    ) -> Tuple[ImagenPrompt, int]:
-        return prompt, 0
+    ) -> Tuple[ImagenPrompt, List[int]]:
+        return prompt, []
 
     @staticmethod
     def get_image_type(image: PIL_Image.Image) -> str:
diff --git a/aidial_adapter_vertexai/chat_completion.py b/aidial_adapter_vertexai/chat_completion.py
index a935b54..f7eed58 100644
--- a/aidial_adapter_vertexai/chat_completion.py
+++ b/aidial_adapter_vertexai/chat_completion.py
@@ -1,4 +1,5 @@
 import asyncio
+from typing import List
 
 from aidial_sdk.chat_completion import ChatCompletion, Request, Response, Status
 
@@ -56,9 +57,9 @@ async def chat_completion(self, request: Request, response: Response):
         if n > 1 and params.stream:
             raise ValidationError("n>1 is not supported in streaming mode")
 
-        discarded_messages_count = 0
+        discarded_messages: List[int] = []
         if params.max_prompt_tokens is not None:
-            prompt, discarded_messages_count = await model.truncate_prompt(
+            prompt, discarded_messages = await model.truncate_prompt(
                 prompt, params.max_prompt_tokens
             )
 
@@ -84,4 +85,4 @@ async def generate_response(usage: TokenUsage, choice_idx: int) -> None:
         response.set_usage(usage.prompt_tokens, usage.completion_tokens)
 
         if params.max_prompt_tokens is not None:
-            response.set_discarded_messages(discarded_messages_count)
+            response.set_discarded_messages(discarded_messages)
diff --git a/poetry.lock b/poetry.lock
index 1d4438a..1cbf7de 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2,13 +2,13 @@
 
 [[package]]
 name = "aidial-sdk"
-version = "0.6.2"
+version = "0.7.0"
 description = "Framework to create applications and model adapters for AI DIAL"
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "aidial_sdk-0.6.2-py3-none-any.whl", hash = "sha256:fa1cc43f1f8f70047e81adc5fae9914ddf6c94e4d7f55b83ba7ecca3cea5d122"},
-    {file = "aidial_sdk-0.6.2.tar.gz", hash = "sha256:46dafb6360cad6cea08531d3cea7600d87cda06cd8c86a560330b61d0a492cab"},
+    {file = "aidial_sdk-0.7.0-py3-none-any.whl", hash = "sha256:e22a948011f6ed55d7f7eef4c0f589f26d6e2412a6b55072be5fd37e8adc5752"},
+    {file = "aidial_sdk-0.7.0.tar.gz", hash = "sha256:a239af55a29742c18446df8a8a29ced8fedd2deebc8cf351d565fcbf8299c295"},
 ]
 
 [package.dependencies]
@@ -2973,4 +2973,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "~3.11"
-content-hash = "4edc35ee316c98f6e54b28a9819f7cc4e556dd99f49132ae5698fc18d2f2bd9d"
+content-hash = "3a7a1eba7057c3385c31e422504490b6d2ecd6acd898e113cbbedaf19ec662b1"
diff --git a/pyproject.toml b/pyproject.toml
index caf8112..7256f3c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ repository = "https://github.com/epam/ai-dial-adapter-vertexai/"
 
 [tool.poetry.dependencies]
 python = "~3.11"
-aidial-sdk = {version = "0.6.2", extras = ["telemetry"]}
+aidial-sdk = {version = "0.7.0", extras = ["telemetry"]}
 fastapi = "0.109.2"
 google-cloud-aiplatform = "1.38.1"
 google-auth = "2.21.0"