feat: supported Gemini Pro Vision (#34)
adubovik authored Jan 12, 2024
1 parent 36f9428 commit f9302ed
Showing 22 changed files with 660 additions and 97 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -53,6 +53,7 @@ Copy `.env.example` to `.env` and customize it for your environment:
|AIDIAL_LOG_LEVEL|WARNING|AI DIAL SDK log level|
|WEB_CONCURRENCY|1|Number of workers for the server|
|TEST_SERVER_URL|http://0.0.0.0:5001|Server URL used in the integration tests|
|DIAL_URL||URL of the core DIAL server. Optional. Used to access images stored in the DIAL File storage|

### Docker

20 changes: 18 additions & 2 deletions aidial_adapter_vertexai/chat_completion.py
@@ -1,11 +1,12 @@
import asyncio

from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from aidial_sdk.chat_completion import ChatCompletion, Request, Response, Status

from aidial_adapter_vertexai.llm.chat_completion_adapter import (
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.llm.consumer import ChoiceConsumer
from aidial_adapter_vertexai.llm.exceptions import UserError
from aidial_adapter_vertexai.llm.vertex_ai_adapter import (
get_chat_completion_model,
)
@@ -28,15 +29,30 @@ def __init__(self, region: str, project_id: str):

@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
headers = request.headers
model: ChatCompletionAdapter = await get_chat_completion_model(
deployment=ChatCompletionDeployment(request.deployment_id),
project_id=self.project_id,
location=self.region,
headers=headers,
)

params = ModelParameters.create(request)
prompt = await model.parse_prompt(request.messages)

if isinstance(prompt, UserError):
# Show the error message in a stage for a web UI user
with response.create_choice() as choice:
stage = choice.create_stage("Error")
stage.open()
stage.append_content(prompt.to_message_for_chat_user())
stage.close(Status.FAILED)
await response.aflush()

# Raise exception for a DIAL API client
raise Exception(prompt.message)

params = ModelParameters.create(request)

discarded_messages_count = 0
if params.max_prompt_tokens is not None:
prompt, discarded_messages_count = await model.truncate_prompt(
@@ -97,7 +97,6 @@ async def count_completion_tokens(self, string: str) -> int:
self._create_instance(BisonPrompt(context=None, messages=messages))
)

@override
@classmethod
async def create(cls, model_id: str, project_id: str, location: str):
model = get_vertex_ai_chat(model_id, project_id, location)
10 changes: 2 additions & 8 deletions aidial_adapter_vertexai/llm/chat_completion_adapter.py
@@ -4,21 +4,15 @@
from aidial_sdk.chat_completion import Message

from aidial_adapter_vertexai.llm.consumer import Consumer
from aidial_adapter_vertexai.llm.exceptions import UserError
from aidial_adapter_vertexai.universal_api.request import ModelParameters

P = TypeVar("P")


class ChatCompletionAdapter(ABC, Generic[P]):
@classmethod
@abstractmethod
async def create(
cls, model_id: str, project_id: str, location: str
) -> "ChatCompletionAdapter":
pass

@abstractmethod
async def parse_prompt(self, messages: List[Message]) -> P:
async def parse_prompt(self, messages: List[Message]) -> P | UserError:
pass

@abstractmethod
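The abstract `create` factory is dropped from the base class (each adapter now defines its own), and `parse_prompt` can now return a `UserError` instead of raising. A minimal, hypothetical sketch of an adapter following the new contract (the `EchoAdapter` name and its logic are illustrative, not part of this commit; the remaining abstract methods are omitted, so the class is deliberately incomplete):

```python
from typing import List

from aidial_sdk.chat_completion import Message

from aidial_adapter_vertexai.llm.chat_completion_adapter import ChatCompletionAdapter
from aidial_adapter_vertexai.llm.exceptions import UserError


class EchoAdapter(ChatCompletionAdapter[str]):
    # Illustrative only: truncate_prompt, chat and the token-counting methods
    # would still have to be implemented before this class could be used.
    async def parse_prompt(self, messages: List[Message]) -> str | UserError:
        if not messages or messages[-1].content is None:
            # Returned rather than raised: the caller decides how to surface it.
            return UserError("No text prompt was found", "Send a plain text message.")
        return messages[-1].content
```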
10 changes: 10 additions & 0 deletions aidial_adapter_vertexai/llm/exceptions.py
@@ -2,3 +2,13 @@ class ValidationError(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(self.message)


class UserError(Exception):
def __init__(self, message: str, usage: str):
self.message = message
self.usage = usage
super().__init__(self.message)

def to_message_for_chat_user(self) -> str:
return f"{self.message}\n\n{self.usage}"
44 changes: 32 additions & 12 deletions aidial_adapter_vertexai/llm/gemini_chat_completion_adapter.py
@@ -1,4 +1,5 @@
from typing import AsyncIterator, Dict, List, Tuple
from logging import DEBUG
from typing import AsyncIterator, Dict, List, Optional, Tuple

from aidial_sdk.chat_completion import Message
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
@@ -13,13 +14,16 @@
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.llm.consumer import Consumer
from aidial_adapter_vertexai.llm.exceptions import UserError
from aidial_adapter_vertexai.llm.gemini_prompt import GeminiPrompt
from aidial_adapter_vertexai.llm.vertex_ai import (
get_gemini_model,
init_vertex_ai,
)
from aidial_adapter_vertexai.universal_api.request import ModelParameters
from aidial_adapter_vertexai.universal_api.storage import FileStorage
from aidial_adapter_vertexai.universal_api.token_usage import TokenUsage
from aidial_adapter_vertexai.utils.json import json_dumps_short
from aidial_adapter_vertexai.utils.log_config import vertex_ai_logger as log
from aidial_adapter_vertexai.utils.timer import Timer

@@ -48,12 +52,24 @@ def create_generation_config(params: ModelParameters) -> GenerationConfig:


class GeminiChatCompletionAdapter(ChatCompletionAdapter[GeminiPrompt]):
def __init__(self, model: GenerativeModel):
def __init__(
self,
file_storage: Optional[FileStorage],
model: GenerativeModel,
is_vision_model: bool,
):
self.file_storage = file_storage
self.model = model
self.is_vision_model = is_vision_model

@override
async def parse_prompt(self, messages: List[Message]) -> GeminiPrompt:
return GeminiPrompt.parse(messages)
async def parse_prompt(
self, messages: List[Message]
) -> GeminiPrompt | UserError:
if self.is_vision_model:
return await GeminiPrompt.parse_vision(self.file_storage, messages)
else:
return GeminiPrompt.parse_non_vision(messages)

@override
async def truncate_prompt(
@@ -98,11 +114,11 @@ async def chat(
prompt_tokens = await self.count_prompt_tokens(prompt)

with Timer("predict timing: {time}", log.debug):
log.debug(
"predict request: "
f"parameters=({params}), "
f"prompt=({prompt}), "
)
if log.isEnabledFor(DEBUG):
log.debug(
"predict request: "
+ json_dumps_short({"parameters": params, "prompt": prompt})
)

completion = ""

@@ -133,11 +149,15 @@ async def count_completion_tokens(self, string: str) -> int:
resp = await self.model.count_tokens_async(string)
return resp.total_tokens

@override
@classmethod
async def create(
cls, model_id: str, project_id: str, location: str
cls,
file_storage: Optional[FileStorage],
model_id: str,
is_vision_model: bool,
project_id: str,
location: str,
) -> "GeminiChatCompletionAdapter":
await init_vertex_ai(project_id, location)
model = await get_gemini_model(model_id)
return cls(model)
return cls(file_storage, model, is_vision_model)
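A hypothetical call site for the extended factory, sketched from the signature above. The project, region, and the `file_storage=None` shortcut are placeholders; in the adapter itself a `FileStorage` is presumably built from `DIAL_URL` and the request headers (see the README change and `chat_completion.py` above) when attached images live in the DIAL File storage:

```python
import asyncio

from aidial_adapter_vertexai.llm.gemini_chat_completion_adapter import (
    GeminiChatCompletionAdapter,
)


async def make_vision_adapter() -> GeminiChatCompletionAdapter:
    # Placeholder project and region; pass a FileStorage instead of None
    # when attached images are stored in the DIAL File storage.
    return await GeminiChatCompletionAdapter.create(
        file_storage=None,
        model_id="gemini-pro-vision",
        is_vision_model=True,
        project_id="my-gcp-project",
        location="us-central1",
    )


adapter = asyncio.run(make_vision_adapter())
```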
168 changes: 128 additions & 40 deletions aidial_adapter_vertexai/llm/gemini_prompt.py
@@ -1,11 +1,28 @@
from typing import List, assert_never
import base64
from typing import List, Optional, Union, assert_never

from aidial_sdk.chat_completion import Message, Role
from pydantic import BaseModel
from vertexai.preview.generative_models import ChatSession, Content, Part

from aidial_adapter_vertexai.llm.exceptions import ValidationError
from aidial_adapter_vertexai.utils.list import cluster_by
from aidial_adapter_vertexai.llm.exceptions import UserError, ValidationError
from aidial_adapter_vertexai.llm.process_inputs import (
MessageWithInputs,
download_inputs,
)
from aidial_adapter_vertexai.universal_api.storage import FileStorage

# Pricing info: https://cloud.google.com/vertex-ai/pricing
# Supported image types:
# https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/send-multimodal-prompts?authuser=1#image-requirements
SUPPORTED_IMAGE_TYPES = ["image/jpeg", "image/png"]
SUPPORTED_FILE_EXTS = ["jpg", "jpeg", "png"]
# NOTE: Tokens per image: 258. count_tokens API call takes this into account.
# Up to 16 images. Total max size 4MB.

# NOTE: See also supported video formats:
# https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/send-multimodal-prompts?authuser=1#video-requirements
# Tokens per video: 1032


class GeminiPrompt(BaseModel):
@@ -16,18 +33,51 @@ class Config:
arbitrary_types_allowed = True

@classmethod
def parse(cls, messages: List[Message]) -> "GeminiPrompt":
def parse_non_vision(cls, messages: List[Message]) -> "GeminiPrompt":
if len(messages) == 0:
raise ValidationError(
"The chat history must have at least one message"
)

simple_messages = list(map(SimpleMessage.from_message, messages))
history = [
SimpleMessage.from_messages(cluster).to_content()
for cluster in cluster_by(lambda c: c.role, simple_messages)
messages = accommodate_first_system_message(messages)

msgs = [
MessageWithInputs(message=message, image_inputs=[])
for message in messages
]

history = list(map(to_content, msgs))
return cls(history=history[:-1], prompt=history[-1].parts)

@classmethod
async def parse_vision(
cls,
file_storage: Optional[FileStorage],
messages: List[Message],
) -> Union["GeminiPrompt", UserError]:
if len(messages) == 0:
raise ValidationError(
"The chat history must have at least one message"
)

# NOTE: Vision model can't handle multiple messages with images.
# It throws "Invalid request 500" error.
messages = messages[-1:]

download_result = await download_inputs(
file_storage, SUPPORTED_IMAGE_TYPES, messages
)

usage_message = get_usage_message(SUPPORTED_FILE_EXTS)

if isinstance(download_result, str):
return UserError(download_result, usage_message)

image_count = sum(len(msg.image_inputs) for msg in download_result)
if image_count == 0:
return UserError("No images inputs were found", usage_message)

history = list(map(to_content, download_result))
return cls(history=history[:-1], prompt=history[-1].parts)

@property
@@ -37,39 +87,77 @@ def contents(self) -> List[Content]:
]


class SimpleMessage(BaseModel):
role: str
content: str
def accommodate_first_system_message(messages: List[Message]) -> List[Message]:
if len(messages) == 0:
return messages

@classmethod
def from_message(cls, message: Message) -> "SimpleMessage":
content = message.content
if content is None:
raise ValueError("Message content must be present")

match message.role:
case Role.SYSTEM:
role = ChatSession._USER_ROLE
case Role.USER:
role = ChatSession._USER_ROLE
case Role.ASSISTANT:
role = ChatSession._MODEL_ROLE
case Role.FUNCTION | Role.TOOL:
raise ValidationError("Function messages are not supported")
case _:
assert_never(message.role)

return SimpleMessage(role=role, content=content)
first_message: Message = messages[0]
if first_message.role != Role.SYSTEM:
return messages

@classmethod
def from_messages(cls, messages: List["SimpleMessage"]) -> "SimpleMessage":
if len(messages) == 0:
raise ValueError("Messages must not be empty")
if len(messages) == 1:
first_message = first_message.copy()
first_message.role = Role.USER
return [first_message]

return SimpleMessage(
role=messages[0].role,
content="\n".join(message.content for message in messages),
)
second_message = messages[1]
if second_message.role != Role.USER:
return messages

if first_message.content is None or second_message.content is None:
return messages

content = first_message.content + "\n" + second_message.content
return [Message(role=Role.USER, content=content)] + messages[2:]


def to_content(msg: MessageWithInputs) -> Content:
message = msg.message
content = message.content
if content is None:
raise ValidationError("Message content must be present")

parts: List[Part] = []

for image in msg.image_inputs:
data = base64.b64decode(image.data, validate=True)
parts.append(Part.from_data(data=data, mime_type=image.type))

parts.append(Part.from_text(content))

return Content(role=get_part_role(message.role), parts=parts)


def get_part_role(role: Role) -> str:
match role:
case Role.SYSTEM:
raise ValidationError(
"System messages other than the first system message are not allowed"
)
case Role.USER:
return ChatSession._USER_ROLE
case Role.ASSISTANT:
return ChatSession._MODEL_ROLE
case Role.FUNCTION:
raise ValidationError("Function messages are not supported")
case Role.TOOL:
raise ValidationError("Tool messages are not supported")
case _:
assert_never(role)


def to_content(self) -> Content:
return Content(role=self.role, parts=[Part.from_text(self.content)])


def get_usage_message(supported_exts: List[str]) -> str:
return f"""
### Usage
The application answers queries about attached images.
Attach images and ask questions about them in the same message.
Only the last message will be taken into account.
Supported image types: {', '.join(supported_exts)}.
Examples of queries:
- "Describe this picture" for one image,
- "What are in these images? Is there any difference between them?" for multiple images.
""".strip()