
feat: added basic support of gemini pro #33

Merged
11 commits, Jan 10, 2024
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -19,7 +19,7 @@
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": true
"source.organizeImports": "explicit"
},
"editor.tabSize": 4
},
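Note: newer VS Code builds deprecated boolean values for entries under editor.codeActionsOnSave in favor of the strings "explicit", "always", and "never"; "explicit" keeps the previous behavior of organizing imports only on an explicit save.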
108 changes: 12 additions & 96 deletions aidial_adapter_vertexai/chat_completion.py
@@ -1,26 +1,14 @@
import asyncio
from typing import List, Optional, Tuple

from aidial_sdk.chat_completion import (
ChatCompletion,
Message,
Request,
Response,
Role,
)
from aidial_sdk.chat_completion import ChatCompletion, Request, Response

from aidial_adapter_vertexai.llm.consumer import ChoiceConsumer
from aidial_adapter_vertexai.llm.exceptions import ValidationError
from aidial_adapter_vertexai.llm.history_trimming import (
get_discarded_messages_count,
from aidial_adapter_vertexai.llm.chat_completion_adapter import (
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.llm.consumer import ChoiceConsumer
from aidial_adapter_vertexai.llm.vertex_ai_adapter import (
get_chat_completion_model,
)
from aidial_adapter_vertexai.llm.vertex_ai_chat import (
VertexAIAuthor,
VertexAIMessage,
)
from aidial_adapter_vertexai.llm.vertex_ai_deployments import (
ChatCompletionDeployment,
)
@@ -29,78 +17,6 @@
from aidial_adapter_vertexai.universal_api.token_usage import TokenUsage
from aidial_adapter_vertexai.utils.log_config import app_logger as log

_SUPPORTED_ROLES = {Role.SYSTEM, Role.USER, Role.ASSISTANT}


def _parse_message(message: Message) -> VertexAIMessage:
author = (
VertexAIAuthor.BOT
if message.role == Role.ASSISTANT
else VertexAIAuthor.USER
)
return VertexAIMessage(author=author, content=message.content) # type: ignore


def _validate_messages_and_split(
messages: List[Message],
) -> Tuple[Optional[str], List[Message]]:
if len(messages) == 0:
raise ValidationError("The chat history must have at least one message")

for message in messages:
if message.content is None:
raise ValidationError("Message content must be present")

if message.role not in _SUPPORTED_ROLES:
raise ValidationError(
f"Message role must be one of {_SUPPORTED_ROLES}"
)

context: Optional[str] = None
if len(messages) > 0 and messages[0].role == Role.SYSTEM:
context = messages[0].content or ""
context = context if context.strip() else None
messages = messages[1:]

if len(messages) == 0 and context is not None:
raise ValidationError(
"The chat history must have at least one non-system message"
)

role: Optional[Role] = None
for message in messages:
if message.role == Role.SYSTEM:
raise ValidationError(
"System messages other than the initial system message are not allowed"
)

# Bison doesn't support empty messages,
# so we replace it with a single space.
message.content = message.content or " "

if role == message.role:
raise ValidationError("Messages must alternate between authors")

role = message.role

if len(messages) % 2 == 0:
raise ValidationError(
"There should be odd number of messages for correct alternating turn"
)

if messages[-1].role != Role.USER:
raise ValidationError("The last message must be a user message")

return context, messages


def _parse_history(
history: List[Message],
) -> Tuple[Optional[str], List[VertexAIMessage]]:
context, history = _validate_messages_and_split(history)

return context, list(map(_parse_message, history))


class VertexAIChatCompletion(ChatCompletion):
region: str
@@ -112,25 +28,25 @@ def __init__(self, region: str, project_id: str):

@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
model = await get_chat_completion_model(
model: ChatCompletionAdapter = await get_chat_completion_model(
deployment=ChatCompletionDeployment(request.deployment_id),
project_id=self.project_id,
location=self.region,
)

params = ModelParameters.create(request)
context, messages = _parse_history(request.messages)
discarded_messages_count: Optional[int] = None
prompt = await model.parse_prompt(request.messages)

discarded_messages_count = 0
if params.max_prompt_tokens is not None:
discarded_messages_count = await get_discarded_messages_count(
model, context, messages, params.max_prompt_tokens
prompt, discarded_messages_count = await model.truncate_prompt(
prompt, params.max_prompt_tokens
)
messages = messages[discarded_messages_count:]

async def generate_response(usage: TokenUsage, choice_idx: int) -> None:
with response.create_choice() as choice:
consumer = ChoiceConsumer(choice)
await model.chat(consumer, context, messages, params)
await model.chat(params, consumer, prompt)
usage.accumulate(consumer.usage)

usage = TokenUsage()
@@ -142,5 +58,5 @@ async def generate_response(usage: TokenUsage, choice_idx: int) -> None:
log.debug(f"usage: {usage}")
response.set_usage(usage.prompt_tokens, usage.completion_tokens)

if discarded_messages_count is not None:
if params.max_prompt_tokens is not None:
response.set_discarded_messages(discarded_messages_count)
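
The chat_completion_adapter module itself is not part of this diff; judging from the calls above, the generic interface it introduces looks roughly like the sketch below (an inference from this PR, not the actual file: the method names match the usage here, but the exact typing is an assumption):

from abc import ABC, abstractmethod
from typing import Generic, List, Tuple, TypeVar

from aidial_sdk.chat_completion import Message

from aidial_adapter_vertexai.llm.consumer import Consumer
from aidial_adapter_vertexai.universal_api.request import ModelParameters

P = TypeVar("P")  # model-specific prompt type, e.g. BisonPrompt


class ChatCompletionAdapter(ABC, Generic[P]):
    # Hypothetical reconstruction of the base class used in chat_completion.py above.

    @abstractmethod
    async def parse_prompt(self, messages: List[Message]) -> P:
        ...

    @abstractmethod
    async def truncate_prompt(self, prompt: P, max_prompt_tokens: int) -> Tuple[P, int]:
        ...

    @abstractmethod
    async def chat(self, params: ModelParameters, consumer: Consumer, prompt: P) -> None:
        ...

    @abstractmethod
    async def count_prompt_tokens(self, prompt: P) -> int:
        ...

    @abstractmethod
    async def count_completion_tokens(self, string: str) -> int:
        ...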
42 changes: 14 additions & 28 deletions aidial_adapter_vertexai/llm/bison_adapter.py
@@ -1,32 +1,25 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict

from typing_extensions import override

from aidial_adapter_vertexai.llm.chat_completion_adapter import (
ChatCompletionAdapter,
from aidial_adapter_vertexai.llm.bison_chat_completion_adapter import (
BisonChatCompletionAdapter,
)
from aidial_adapter_vertexai.llm.bison_prompt import BisonPrompt
from aidial_adapter_vertexai.llm.exceptions import ValidationError
from aidial_adapter_vertexai.llm.vertex_ai_chat import VertexAIMessage
from aidial_adapter_vertexai.universal_api.request import ModelParameters


class BisonChatAdapter(ChatCompletionAdapter):
class BisonChatAdapter(BisonChatCompletionAdapter):
@override
def _create_instance(
self,
context: Optional[str],
messages: List[VertexAIMessage],
) -> Dict[str, Any]:
def _create_instance(self, prompt: BisonPrompt) -> Dict[str, Any]:
return {
"context": context or "",
"messages": messages,
"context": prompt.context or "",
"messages": prompt.messages,
}

@override
def _create_parameters(
self,
params: ModelParameters,
) -> Dict[str, Any]:
def _create_parameters(self, params: ModelParameters) -> Dict[str, Any]:
# See chat playground: https://console.cloud.google.com/vertex-ai/generative/language/create/chat
ret: Dict[str, Any] = {}

@@ -47,25 +40,18 @@ def _create_parameters(
return ret


class BisonCodeChatAdapter(ChatCompletionAdapter):
class BisonCodeChatAdapter(BisonChatCompletionAdapter):
@override
def _create_instance(
self,
context: Optional[str],
messages: List[VertexAIMessage],
) -> Dict[str, Any]:
if context is not None:
def _create_instance(self, prompt: BisonPrompt) -> Dict[str, Any]:
if prompt.context is not None:
raise ValidationError("System message is not supported")

return {
"messages": messages,
"messages": prompt.messages,
}

@override
def _create_parameters(
self,
params: ModelParameters,
) -> Dict[str, Any]:
def _create_parameters(self, params: ModelParameters) -> Dict[str, Any]:
ret: Dict[str, Any] = {}

if params.max_tokens is not None:
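
bison_prompt.py is likewise not included in this excerpt; from its usage in these adapters, BisonPrompt is presumably a small container along the following lines (an illustrative guess: the dataclass form and the body of parse are assumptions):

from dataclasses import dataclass
from typing import List, Optional

from aidial_sdk.chat_completion import Message

from aidial_adapter_vertexai.llm.vertex_ai_chat import VertexAIMessage


@dataclass
class BisonPrompt:
    # Optional system instruction ("context" in the Bison chat API).
    context: Optional[str]
    # Alternating user/bot turns in the Vertex AI message format.
    messages: List[VertexAIMessage]

    @classmethod
    def parse(cls, messages: List[Message]) -> "BisonPrompt":
        # The validation and system-message splitting removed from
        # chat_completion.py presumably lives here now.
        raise NotImplementedError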
104 changes: 104 additions & 0 deletions aidial_adapter_vertexai/llm/bison_chat_completion_adapter.py
@@ -0,0 +1,104 @@
import asyncio
from abc import abstractmethod
from typing import Any, Dict, List, Tuple

from aidial_sdk.chat_completion import Message
from typing_extensions import override

from aidial_adapter_vertexai.llm.bison_history_trimming import (
get_discarded_messages_count,
)
from aidial_adapter_vertexai.llm.bison_prompt import BisonPrompt
from aidial_adapter_vertexai.llm.chat_completion_adapter import (
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.llm.consumer import Consumer
from aidial_adapter_vertexai.llm.vertex_ai import get_vertex_ai_chat
from aidial_adapter_vertexai.llm.vertex_ai_chat import (
VertexAIAuthor,
VertexAIChat,
VertexAIMessage,
)
from aidial_adapter_vertexai.universal_api.request import ModelParameters
from aidial_adapter_vertexai.universal_api.token_usage import TokenUsage


class BisonChatCompletionAdapter(ChatCompletionAdapter[BisonPrompt]):
def __init__(self, model: VertexAIChat):
self.model = model

@abstractmethod
def _create_instance(self, prompt: BisonPrompt) -> Dict[str, Any]:
pass

@abstractmethod
def _create_parameters(self, params: ModelParameters) -> Dict[str, Any]:
pass

@override
async def parse_prompt(self, messages: List[Message]) -> BisonPrompt:
return BisonPrompt.parse(messages)

@override
async def truncate_prompt(
self, prompt: BisonPrompt, max_prompt_tokens: int
) -> Tuple[BisonPrompt, int]:
if max_prompt_tokens is None:
return prompt, 0

discarded = await get_discarded_messages_count(
self, prompt, max_prompt_tokens
)

return (
BisonPrompt(
context=prompt.context,
messages=prompt.messages[discarded:],
),
discarded,
)

@override
async def chat(
self, params: ModelParameters, consumer: Consumer, prompt: BisonPrompt
) -> None:
content_task = self.model.predict(
params.stream,
consumer,
self._create_instance(prompt),
self._create_parameters(params),
)

if params.stream:
# Token usage isn't reported for streaming requests.
# Computing it manually
prompt_tokens, content = await asyncio.gather(
self.count_prompt_tokens(prompt), content_task
)
completion_tokens = await self.count_completion_tokens(content)

await consumer.set_usage(
TokenUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
)
else:
await content_task

@override
async def count_prompt_tokens(self, prompt: BisonPrompt) -> int:
return await self.model.count_tokens(self._create_instance(prompt))

@override
async def count_completion_tokens(self, string: str) -> int:
messages = [VertexAIMessage(author=VertexAIAuthor.USER, content=string)]
return await self.model.count_tokens(
self._create_instance(BisonPrompt(context=None, messages=messages))
)

@override
@classmethod
async def create(cls, model_id: str, project_id: str, location: str):
model = get_vertex_ai_chat(model_id, project_id, location)
return cls(model)
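
A rough usage sketch of the resulting adapter stack, assuming Vertex AI credentials are configured; the model id, project and region below are placeholders:

import asyncio

from aidial_sdk.chat_completion import Message, Role

from aidial_adapter_vertexai.llm.bison_adapter import BisonChatAdapter


async def main() -> None:
    adapter = await BisonChatAdapter.create(
        model_id="chat-bison@001",
        project_id="my-gcp-project",
        location="us-central1",
    )

    prompt = await adapter.parse_prompt(
        [Message(role=Role.USER, content="Hello!")]
    )
    # truncate_prompt drops the oldest turns if the prompt exceeds the limit
    # and reports how many messages were discarded.
    prompt, discarded = await adapter.truncate_prompt(prompt, max_prompt_tokens=1024)
    print(await adapter.count_prompt_tokens(prompt), discarded)


asyncio.run(main())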