
Commit ff52263
fix: migrated to sdk==0.7.0; turned discarded_messages into a list (#58)
adubovik authored Mar 5, 2024
1 parent 508357c commit ff52263
Showing 8 changed files with 38 additions and 26 deletions.
18 changes: 4 additions & 14 deletions aidial_adapter_vertexai/chat/bison/base.py
@@ -11,7 +11,7 @@

 from aidial_adapter_vertexai.chat.bison.prompt import BisonPrompt
 from aidial_adapter_vertexai.chat.bison.truncate_prompt import (
-    get_discarded_messages_count,
+    get_discarded_messages,
 )
 from aidial_adapter_vertexai.chat.chat_completion_adapter import (
     ChatCompletionAdapter,

@@ -42,21 +42,11 @@ async def parse_prompt(self, messages: List[Message]) -> BisonPrompt:
     @override
     async def truncate_prompt(
         self, prompt: BisonPrompt, max_prompt_tokens: int
-    ) -> Tuple[BisonPrompt, int]:
+    ) -> Tuple[BisonPrompt, List[int]]:
         if max_prompt_tokens is None:
-            return prompt, 0
+            return prompt, []

-        discarded = await get_discarded_messages_count(
-            self, prompt, max_prompt_tokens
-        )
-
-        return (
-            BisonPrompt(
-                context=prompt.context,
-                messages=prompt.messages[discarded:],
-            ),
-            discarded,
-        )
+        return await get_discarded_messages(self, prompt, max_prompt_tokens)

     @override
     async def chat(
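
For illustration, a minimal sketch of the new truncate_prompt contract from a caller's perspective (the adapter instance, messages, and token limit below are hypothetical placeholders, not part of this diff):

# Hypothetical usage sketch: truncate_prompt now returns the truncated
# prompt together with the indices of the discarded messages, where the
# old API returned a bare count.
prompt = BisonPrompt(context=None, messages=messages)
truncated, discarded = await adapter.truncate_prompt(prompt, max_prompt_tokens=100)
# Old API: discarded == 2 (an int).
# New API: discarded == [0, 1], i.e. the first two messages were dropped.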
21 changes: 21 additions & 0 deletions aidial_adapter_vertexai/chat/bison/truncate_prompt.py
@@ -1,3 +1,5 @@
+from typing import List, Tuple
+
 from aidial_adapter_vertexai.chat.bison.prompt import BisonPrompt
 from aidial_adapter_vertexai.chat.chat_completion_adapter import (
     ChatCompletionAdapter,

@@ -77,3 +79,22 @@ async def get_discarded_messages_count(
         discarded_messages_count -= 2

     return discarded_messages_count
+
+
+async def get_discarded_messages(
+    model: ChatCompletionAdapter[BisonPrompt],
+    prompt: BisonPrompt,
+    max_prompt_tokens: int,
+) -> Tuple[BisonPrompt, List[int]]:
+    count = await get_discarded_messages_count(model, prompt, max_prompt_tokens)
+
+    truncated_prompt = BisonPrompt(
+        context=prompt.context,
+        messages=prompt.messages[count:],
+    )
+
+    discarded_indices = list(range(count))
+    if prompt.context is not None:
+        discarded_indices = list(map(lambda x: x + 1, discarded_indices))
+
+    return truncated_prompt, discarded_indices
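
The index arithmetic in get_discarded_messages is easiest to see with a worked example; this sketch assumes that, when a context is present, it occupies index 0 of the original request, so every dropped chat-message index shifts up by one:

# Worked example with hypothetical values.
count = 2                               # two chat messages must be dropped
discarded = list(range(count))          # [0, 1]
# With a context message at index 0 of the original request, the dropped
# chat messages actually sit at indices 1 and 2:
discarded = [i + 1 for i in discarded]  # [1, 2]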
2 changes: 1 addition & 1 deletion aidial_adapter_vertexai/chat/chat_completion_adapter.py
@@ -18,7 +18,7 @@ async def parse_prompt(self, messages: List[Message]) -> P | UserError:
     @abstractmethod
     async def truncate_prompt(
         self, prompt: P, max_prompt_tokens: int
-    ) -> Tuple[P, int]:
+    ) -> Tuple[P, List[int]]:
         pass

     @abstractmethod
2 changes: 1 addition & 1 deletion aidial_adapter_vertexai/chat/gemini/adapter.py
@@ -91,7 +91,7 @@ async def parse_prompt(
     @override
     async def truncate_prompt(
         self, prompt: GeminiPrompt, max_prompt_tokens: int
-    ) -> Tuple[GeminiPrompt, int]:
+    ) -> Tuple[GeminiPrompt, List[int]]:
         raise NotImplementedError(
             "Prompt truncation is not supported for Gemini model yet"
         )
4 changes: 2 additions & 2 deletions aidial_adapter_vertexai/chat/imagen/adapter.py
@@ -50,8 +50,8 @@ async def parse_prompt(self, messages: List[Message]) -> ImagenPrompt:
     @override
     async def truncate_prompt(
         self, prompt: ImagenPrompt, max_prompt_tokens: int
-    ) -> Tuple[ImagenPrompt, int]:
-        return prompt, 0
+    ) -> Tuple[ImagenPrompt, List[int]]:
+        return prompt, []

     @staticmethod
     def get_image_type(image: PIL_Image.Image) -> str:
7 changes: 4 additions & 3 deletions aidial_adapter_vertexai/chat_completion.py
@@ -1,4 +1,5 @@
 import asyncio
+from typing import List

 from aidial_sdk.chat_completion import ChatCompletion, Request, Response, Status

@@ -56,9 +57,9 @@ async def chat_completion(self, request: Request, response: Response):
         if n > 1 and params.stream:
             raise ValidationError("n>1 is not supported in streaming mode")

-        discarded_messages_count = 0
+        discarded_messages: List[int] = []
         if params.max_prompt_tokens is not None:
-            prompt, discarded_messages_count = await model.truncate_prompt(
+            prompt, discarded_messages = await model.truncate_prompt(
                 prompt, params.max_prompt_tokens
             )

@@ -84,4 +85,4 @@ async def generate_response(usage: TokenUsage, choice_idx: int) -> None:
             response.set_usage(usage.prompt_tokens, usage.completion_tokens)

         if params.max_prompt_tokens is not None:
-            response.set_discarded_messages(discarded_messages_count)
+            response.set_discarded_messages(discarded_messages)
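
Assuming the DIAL API reports this value to clients under statistics.discarded_messages (the response shape itself is not shown in this diff), the observable change is roughly:

# Before (aidial-sdk 0.6.2): "statistics": {"discarded_messages": 2}
# After  (aidial-sdk 0.7.0): "statistics": {"discarded_messages": [0, 1]}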
8 changes: 4 additions & 4 deletions poetry.lock

Generated file; diff not rendered.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ repository = "https://github.com/epam/ai-dial-adapter-vertexai/"

 [tool.poetry.dependencies]
 python = "~3.11"
-aidial-sdk = {version = "0.6.2", extras = ["telemetry"]}
+aidial-sdk = {version = "0.7.0", extras = ["telemetry"]}
 fastapi = "0.109.2"
 google-cloud-aiplatform = "1.38.1"
 google-auth = "2.21.0"
