Merge pull request #179 from mobiusml/vllm_deployment_image_support
VLM Support for vLLM Deployment
movchan74 authored Sep 25, 2024
2 parents 8cb6ac0 + ba6eb55 commit 1ca768e
Showing 16 changed files with 2,981 additions and 2,287 deletions.
9 changes: 6 additions & 3 deletions aana/core/chat/chat_template.py
@@ -28,7 +28,7 @@ def load_chat_template(chat_template_name: str) -> str:

def apply_chat_template(
tokenizer: PreTrainedTokenizerBase,
dialog: ChatDialog,
dialog: ChatDialog | list[dict],
chat_template_name: str | None = None,
) -> str:
"""Applies a chat template to a list of messages to generate a prompt for the model.
@@ -38,7 +38,7 @@ def apply_chat_template(
Args:
tokenizer (PreTrainedTokenizerBase): The tokenizer to use.
dialog (ChatDialog): The dialog to generate a prompt for.
dialog (ChatDialog | list[dict]): The dialog to generate a prompt for.
chat_template_name (str, optional): The name of the chat template to use. Defaults to None, which uses the tokenizer's default chat template.
Returns:
@@ -48,7 +48,10 @@ def apply_chat_template(
ValueError: If the tokenizer does not have a chat template.
ValueError: If the chat template does not exist.
"""
messages = dialog.model_dump()["messages"]
if isinstance(dialog, ChatDialog):
messages = dialog.model_dump()["messages"]
else:
messages = dialog

if chat_template_name is not None:
chat_template = load_chat_template(chat_template_name)
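
Note: with this change, apply_chat_template accepts a plain list of message dicts in addition to a ChatDialog. A minimal sketch of the new path (the tokenizer checkpoint is illustrative, not part of this PR):

    from transformers import AutoTokenizer

    from aana.core.chat.chat_template import apply_chat_template

    # Any HF checkpoint that ships a chat template works; this one is illustrative.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    messages = [{"role": "user", "content": "Describe the image."}]

    # A raw list[dict] is now used as-is instead of requiring a ChatDialog.
    prompt = apply_chat_template(tokenizer, messages)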
208 changes: 197 additions & 11 deletions aana/deployments/vllm_deployment.py
@@ -1,3 +1,4 @@
import base64
import contextlib
from collections.abc import AsyncGenerator
from typing import Any
@@ -9,19 +10,28 @@
from vllm.inputs import TokensPrompt

from aana.core.models.base import merged_options
from aana.core.models.chat import ChatDialog, ChatMessage
from aana.core.models.custom_config import CustomConfig
from aana.core.models.image import Image
from aana.core.models.image_chat import ImageChatDialog
from aana.core.models.types import Dtype
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_text_generation_deployment import (
BaseTextGenerationDeployment,
ChatOutput,
LLMBatchOutput,
LLMOutput,
)
from aana.utils.gpu import get_gpu_memory

with contextlib.suppress(ImportError):
from vllm.model_executor.utils import (
set_random_seed, # Ignore if we don't have GPU and only run on CPU with test cache
)
from vllm.model_executor.utils import set_random_seed
from vllm.entrypoints.chat_utils import (
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages,
)
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import random_uuid

from aana.core.models.base import pydantic_protected_fields
@@ -63,7 +73,7 @@ class VLLMConfig(BaseModel):


@serve.deployment
class VLLMDeployment(BaseTextGenerationDeployment):
class VLLMDeployment(BaseDeployment):
"""Deployment to serve large language models using vLLM."""

async def apply_config(self, config: dict[str, Any]):
@@ -117,27 +127,101 @@ async def apply_config(self, config: dict[str, Any]):
self.tokenizer = self.engine.engine.tokenizer.tokenizer
self.model_config = await self.engine.get_model_config()

def apply_chat_template(
self, dialog: ChatDialog | ImageChatDialog
) -> tuple[str | list[int], dict | None]:
"""Apply the chat template to the dialog.
Args:
dialog (ChatDialog | ImageChatDialog): the dialog (optionally with images)
Returns:
tuple[str | list[int], dict | None]: the prompt and the multimodal data
"""

def image_to_base64(image: Image) -> str:
"""Convert an image to a base64 string."""
image_data = image.get_content()
base64_encoded_image = base64.b64encode(image_data)
base64_string = "data:image;base64," + base64_encoded_image.decode("utf-8")
return base64_string

def replace_image_type(messages: list[dict], images: list[Image]) -> list[dict]:
"""Replace the image type with image_url for compatibility with vLLM chat utils.
vLLM chat utils (parse_chat_messages) only support image_url type for images.
We need to replace the image type with image_url and provide the actual image content as base64 string.
"""
i = 0
for message in messages:
for item in message["content"]:
if item["type"] == "image":
item["type"] = "image_url"
if i >= len(images):
raise ValueError( # noqa: TRY003
"Number of images does not match the number of image items in the message."
)
item["image_url"] = {"url": image_to_base64(images[i])}
i += 1
if i != len(images):
raise ValueError( # noqa: TRY003
"Number of images does not match the number of image items in the message."
)
return messages

if isinstance(dialog, ImageChatDialog):
messages, images = dialog.to_objects()
messages = replace_image_type(messages, images)
else:
messages = dialog.model_dump()["messages"]
images = None

conversation, mm_data = parse_chat_messages(
messages, self.model_config, self.tokenizer
)

if isinstance(self.tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
self.tokenizer,
messages=messages,
add_generation_prompt=True,
)
else:
prompt = apply_hf_chat_template(
self.tokenizer,
conversation=conversation,
chat_template=self.tokenizer.chat_template,
add_generation_prompt=True,
# tokenize=True
)
return prompt, mm_data
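
Note: for reference, a sketch of the message shape this conversion handles; the values are illustrative and not taken from this diff:

    # An ImageChatDialog serializes to content items like this (illustrative):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is in this picture?"},
            ],
        }
    ]
    # replace_image_type rewrites each {"type": "image"} item in place into
    # {"type": "image_url", "image_url": {"url": "data:image;base64,<encoded bytes>"}}
    # so that vLLM's parse_chat_messages can consume it.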

async def generate_stream(
self, prompt: str, sampling_params: SamplingParams | None = None
self,
prompt: str | list[int],
sampling_params: SamplingParams | None = None,
mm_data: dict | None = None,
) -> AsyncGenerator[LLMOutput, None]:
"""Generate completion for the given prompt and stream the results.
Args:
prompt (str): the prompt
prompt (str | list[int]): the prompt or the tokenized prompt
mm_data (dict | None): the multimodal data
sampling_params (SamplingParams | None): the sampling parameters
Yields:
LLMOutput: the dictionary with the key "text" and the generated text as the value
"""
prompt = str(prompt)
if isinstance(prompt, str):
prompt_token_ids = self.tokenizer.encode(prompt)
else:
prompt_token_ids = prompt

if sampling_params is None:
sampling_params = SamplingParams()
sampling_params = merged_options(self.default_sampling_params, sampling_params)

request_id = None
# tokenize the prompt
prompt_token_ids = self.tokenizer.encode(prompt)

if len(prompt_token_ids) > self.model_config.max_model_len:
raise PromptTooLongException(
@@ -155,10 +239,17 @@ async def generate_stream(
request_id = random_uuid()
# set the random seed for reproducibility
set_random_seed(42)
if mm_data is not None:
inputs = TokensPrompt(
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_data,
)
else:
inputs = TokensPrompt(prompt_token_ids=prompt_token_ids)
results_generator = self.engine.generate(
sampling_params=sampling_params_vllm,
request_id=request_id,
inputs=TokensPrompt(prompt_token_ids=prompt_token_ids),
inputs=inputs,
)

num_returned = 0
@@ -173,3 +264,98 @@
raise
except Exception as e:
raise InferenceException(model_name=self.model_id) from e
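
Note: a hypothetical call into the extended generate_stream; the async context is assumed, and the SamplingParams import path and field names are assumptions for illustration:

    from aana.core.models.sampling import SamplingParams  # assumed import path

    # Inside an async method of the deployment (or via a deployment handle):
    async for chunk in self.generate_stream(
        "Describe the picture.",  # str prompts are tokenized internally
        sampling_params=SamplingParams(temperature=0.0, max_tokens=64),  # assumed fields
        mm_data=None,  # or the dict returned by apply_chat_template
    ):
        print(chunk["text"], end="")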

async def generate(
self,
prompt: str | list[int],
sampling_params: SamplingParams | None = None,
mm_data: dict | None = None,
) -> LLMOutput:
"""Generate completion for the given prompt.
Args:
prompt (str | list[int]): the prompt or the tokenized prompt
mm_data (dict | None): the multimodal data
sampling_params (SamplingParams | None): the sampling parameters
Returns:
LLMOutput: the dictionary with the key "text" and the generated text as the value
"""
generated_text = ""
async for chunk in self.generate_stream(
prompt, sampling_params=sampling_params, mm_data=mm_data
):
generated_text += chunk["text"]
return LLMOutput(text=generated_text)

async def generate_batch(
self,
prompts: list[str] | list[list[int]],
sampling_params: SamplingParams | None = None,
mm_data_list: list[dict] | None = None,
) -> LLMBatchOutput:
"""Generate completion for the batch of prompts.
Args:
prompts (List[str] | List[List[int]]): the list of prompts or the tokenized prompts
mm_data_list (List[dict] | None): the list of multimodal data
sampling_params (SamplingParams | None): the sampling parameters
Returns:
LLMBatchOutput: the dictionary with the key "texts"
and the list of generated texts as the value
"""
texts = []
for i, prompt in enumerate(prompts):
if mm_data_list is not None:
text = await self.generate(
prompt, sampling_params=sampling_params, mm_data=mm_data_list[i]
)
else:
text = await self.generate(prompt, sampling_params=sampling_params)
texts.append(text["text"])

return LLMBatchOutput(texts=texts)
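
Note: when mm_data_list is given it must be index-aligned with prompts; a sketch with placeholder variables:

    # prompt_ids_a/b and mm_data_a/b as returned by apply_chat_template (placeholders):
    batch = await self.generate_batch(
        prompts=[prompt_ids_a, prompt_ids_b],
        mm_data_list=[mm_data_a, mm_data_b],  # one multimodal dict per prompt
    )
    print(batch["texts"])

Design-wise, the loop above awaits each prompt in turn, so the "batch" here is a convenience wrapper over sequential generate calls rather than engine-level batching.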

async def chat(
self,
dialog: ChatDialog | ImageChatDialog,
sampling_params: SamplingParams | None = None,
) -> ChatOutput:
"""Chat with the model.
Args:
dialog (ChatDialog | ImageChatDialog): the dialog (optionally with images)
sampling_params (SamplingParams | None): the sampling parameters
Returns:
ChatOutput: the dictionary with the key "message"
and the response message with a role "assistant"
and the generated text as the content
"""
prompt_ids, mm_data = self.apply_chat_template(dialog)
response = await self.generate(
prompt_ids, sampling_params=sampling_params, mm_data=mm_data
)
response_message = ChatMessage(content=response["text"], role="assistant")
return ChatOutput(message=response_message)
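
Note: end to end, the multimodal chat path looks roughly like this; Image(path=...) and ImageChatDialog.from_prompt are assumed constructors, not introduced in this diff:

    from aana.core.models.image import Image
    from aana.core.models.image_chat import ImageChatDialog

    image = Image(path="photo.jpg")  # assumed constructor arguments
    dialog = ImageChatDialog.from_prompt("What is shown here?", images=[image])

    # apply_chat_template -> generate happens inside chat():
    output = await self.chat(dialog)
    print(output["message"].content)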

async def chat_stream(
self,
dialog: ChatDialog | ImageChatDialog,
sampling_params: SamplingParams | None = None,
) -> AsyncGenerator[LLMOutput, None]:
"""Chat with the model and stream the responses.
Args:
dialog (ChatDialog | ImageChatDialog): the dialog (optionally with images)
sampling_params (SamplingParams | None): the sampling parameters
Yields:
LLMOutput: the dictionary with the key "text" and the generated text as the value
"""
prompt_ids, mm_data = self.apply_chat_template(dialog)
async for chunk in self.generate_stream(
prompt_ids, sampling_params=sampling_params, mm_data=mm_data
):
yield chunk
20 changes: 18 additions & 2 deletions aana/sdk.py
@@ -262,8 +262,12 @@ def unregister_endpoint(self, name: str):
if name in self.endpoints:
del self.endpoints[name]

def wait_for_deployment(self):
def wait_for_deployment(self): # noqa: C901
"""Wait for the deployment to complete."""
consecutive_resource_unavailable = 0
# Number of consecutive checks before raising a resource-unavailable error
resource_unavailable_threshold = 5

while True:
status = serve.status()
if all(
@@ -290,21 +294,33 @@ def wait_for_deployment(self):
f"Error: {deployment_name} ({app_name}): {deployment_status.message}"
)
raise FailedDeployment("\n".join(error_messages))

gcs_address = ray.get_runtime_context().gcs_address
cluster_status = get_cluster_status(gcs_address)
demands = (
cluster_status.resource_demands.cluster_constraint_demand
+ cluster_status.resource_demands.ray_task_actor_demand
+ cluster_status.resource_demands.placement_group_demand
)

resource_unavailable = False
for demand in demands:
if isinstance(demand, ResourceDemand) and demand.bundles_by_count:
error_message = f"Error: No available node types can fulfill resource request {demand.bundles_by_count[0].bundle}. "
if "GPU" in demand.bundles_by_count[0].bundle:
error_message += "Might be due to insufficient or misconfigured CPU or GPU resources."
resource_unavailable = True
else:
error_message = f"Error: {demand}"
raise InsufficientResources(error_message)
resource_unavailable = True

if resource_unavailable:
consecutive_resource_unavailable += 1
if consecutive_resource_unavailable >= resource_unavailable_threshold:
raise InsufficientResources(error_message)
else:
consecutive_resource_unavailable = 0

time.sleep(1) # Wait for 1 second before checking again
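
Note: the check above tolerates transient shortages; InsufficientResources is raised only after five consecutive polls report an unfulfillable demand. A distilled, standalone sketch of the pattern (placeholder callables):

    import time

    def wait_until_done(is_done, resources_ok, threshold: int = 5) -> None:
        """Poll once per second; fail only on a sustained resource shortage."""
        consecutive_unavailable = 0
        while not is_done():
            if resources_ok():
                consecutive_unavailable = 0  # transient shortage cleared, reset
            else:
                consecutive_unavailable += 1
                if consecutive_unavailable >= threshold:
                    raise RuntimeError("No node can fulfill the resource request")
            time.sleep(1)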

def deploy(self, blocking: bool = False):
