Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deployment Health Check and Automatic Restart #191

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions aana/deployments/base_deployment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,44 @@
import inspect
from functools import wraps
from typing import Any

from aana.exceptions.runtime import InferenceException


def exception_handler(func):
"""AanaDeploymentHandle decorator to catch exceptions and store them in the deployment for health check purposes.

Args:
func (function): The function to decorate.

Returns:
function: The decorated function
"""

@wraps(func)
async def wrapper(self, *args, **kwargs):
self.num_requests_since_last_health_check += 1
try:
return await func(self, *args, **kwargs)
except Exception as e:
self.raised_exceptions.append(e)
raise

@wraps(func)
async def wrapper_generator(self, *args, **kwargs):
self.num_requests_since_last_health_check += 1
try:
async for item in func(self, *args, **kwargs):
yield item
except Exception as e:
self.raised_exceptions.append(e)
raise

if inspect.isasyncgenfunction(func):
return wrapper_generator
else:
return wrapper


class BaseDeployment:
"""Base class for all deployments.
Expand All @@ -13,6 +51,9 @@ def __init__(self):
"""Inits to unconfigured state."""
self.config = None
self._configured = False
self.num_requests_since_last_health_check = 0
self.raised_exceptions = []
self.restart_exceptions = [InferenceException]

async def reconfigure(self, config: dict[str, Any]):
"""Reconfigure the deployment.
Expand All @@ -22,6 +63,31 @@ async def reconfigure(self, config: dict[str, Any]):
self.config = config
await self.apply_config(config)
self._configured = True
if "restart_exceptions" in config:
self.restart_exceptions = config["restart_exceptions"]
movchan74 marked this conversation as resolved.
Show resolved Hide resolved

async def check_health(self):
    """Check the health of the deployment.

    Counts the recorded exceptions that are instances of any class listed in
    ``self.restart_exceptions`` (subclasses included) and compares that count
    with the number of requests seen since the last health check. When more
    than 50% of those requests raised a restart exception, the first such
    exception is re-raised; the tracking state is left as is on that path, so
    a subsequent check evaluates the same data. Otherwise the recorded
    exceptions and the request counter are reset.

    Raises:
        Exception: The first recorded restart exception, when restart
            exceptions exceed half of the requests since the last check.
    """
    # Keep only real classes so isinstance() cannot fail on a malformed
    # entry; such entries never matched under the previous exact-class test.
    restart_types = tuple(
        exc_type
        for exc_type in self.restart_exceptions
        if isinstance(exc_type, type)
    )
    # isinstance() also matches subclasses of the configured exceptions,
    # which an exact ``__class__`` membership test would miss.
    raised_restart_exceptions = [
        exception
        for exception in self.raised_exceptions
        if isinstance(exception, restart_types)
    ]
    num_requests = self.num_requests_since_last_health_check
    # Restart the deployment if more than 50% of the requests raised restart exceptions
    if num_requests != 0:
        ratio_restart_exceptions = len(raised_restart_exceptions) / num_requests
        if ratio_restart_exceptions > 0.5:
            raise raised_restart_exceptions[0]

    self.raised_exceptions = []
    self.num_requests_since_last_health_check = 0

async def apply_config(self, config: dict[str, Any]):
"""Apply the configuration.
Expand Down
7 changes: 6 additions & 1 deletion aana/deployments/base_text_generation_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from aana.core.chat.chat_template import apply_chat_template
from aana.core.models.chat import ChatDialog, ChatMessage
from aana.core.models.sampling import SamplingParams
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler


class LLMOutput(TypedDict):
Expand Down Expand Up @@ -57,6 +57,7 @@ class BaseTextGenerationDeployment(BaseDeployment):
You can also override these methods to implement custom inference logic.
"""

@exception_handler
async def generate_stream(
self, prompt: str, sampling_params: SamplingParams | None = None
) -> AsyncGenerator[LLMOutput, None]:
Expand All @@ -71,6 +72,7 @@ async def generate_stream(
"""
raise NotImplementedError

@exception_handler
async def generate(
self, prompt: str, sampling_params: SamplingParams | None = None
) -> LLMOutput:
Expand All @@ -88,6 +90,7 @@ async def generate(
generated_text += chunk["text"]
return LLMOutput(text=generated_text)

@exception_handler
async def generate_batch(
self, prompts: list[str], sampling_params: SamplingParams | None = None
) -> LLMBatchOutput:
Expand All @@ -108,6 +111,7 @@ async def generate_batch(

return LLMBatchOutput(texts=texts)

@exception_handler
async def chat(
self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
) -> ChatOutput:
Expand All @@ -127,6 +131,7 @@ async def chat(
response_message = ChatMessage(content=response["text"], role="assistant")
return ChatOutput(message=response_message)

@exception_handler
async def chat_stream(
self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
) -> AsyncGenerator[LLMOutput, None]:
Expand Down
3 changes: 2 additions & 1 deletion aana/deployments/haystack_component_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ray import serve

from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.utils.asyncio import run_async
from aana.utils.core import import_from_path

Expand Down Expand Up @@ -84,6 +84,7 @@ async def apply_config(self, config: dict[str, Any]):

self.component.warm_up()

@exception_handler
async def run(self, **data: dict[str, Any]) -> dict[str, Any]:
    """Execute the wrapped component with the supplied keyword inputs.

    Args:
        **data: Keyword arguments forwarded unchanged to
            ``self.component.run``.

    Returns:
        dict[str, Any]: The value returned by ``self.component.run``.
    """
    component = self.component
    outputs = component.run(**data)
    return outputs
Expand Down
4 changes: 3 additions & 1 deletion aana/deployments/hf_blip2_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from aana.core.models.captions import Caption, CaptionsList
from aana.core.models.image import Image
from aana.core.models.types import Dtype
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException
from aana.processors.batch import BatchProcessor

Expand Down Expand Up @@ -106,6 +106,7 @@ async def apply_config(self, config: dict[str, Any]):
self.processor = Blip2Processor.from_pretrained(self.model_id)
self.model.to(self.device)

@exception_handler
async def generate(self, image: Image) -> CaptioningOutput:
"""Generate captions for the given image.

Expand All @@ -124,6 +125,7 @@ async def generate(self, image: Image) -> CaptioningOutput:
)
return CaptioningOutput(caption=captions["captions"][0])

@exception_handler
async def generate_batch(self, **kwargs) -> CaptioningBatchOutput:
"""Generate captions for the given images.

Expand Down
3 changes: 2 additions & 1 deletion aana/deployments/hf_pipeline_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from aana.core.models.base import pydantic_protected_fields
from aana.core.models.custom_config import CustomConfig
from aana.core.models.image import Image
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler


class HfPipelineConfig(BaseModel):
Expand Down Expand Up @@ -80,6 +80,7 @@ async def apply_config(self, config: dict[str, Any]):
else:
raise

@exception_handler
async def call(self, *args, **kwargs):
"""Call the pipeline.

Expand Down
2 changes: 2 additions & 0 deletions aana/deployments/hf_text_generation_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from aana.core.models.base import merged_options, pydantic_protected_fields
from aana.core.models.sampling import SamplingParams
from aana.deployments.base_deployment import exception_handler
from aana.deployments.base_text_generation_deployment import (
BaseTextGenerationDeployment,
LLMOutput,
Expand Down Expand Up @@ -48,6 +49,7 @@ class HfTextGenerationConfig(BaseModel):
class BaseHfTextGenerationDeployment(BaseTextGenerationDeployment):
"""Base class for Hugging Face text generation deployments."""

@exception_handler
async def generate_stream(
self, prompt: str, sampling_params: SamplingParams | None = None
) -> AsyncGenerator[LLMOutput, None]:
Expand Down
5 changes: 4 additions & 1 deletion aana/deployments/idefics_2_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from aana.core.models.image_chat import ImageChatDialog
from aana.core.models.sampling import SamplingParams
from aana.core.models.types import Dtype
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.deployments.base_text_generation_deployment import ChatOutput, LLMOutput
from aana.exceptions.runtime import InferenceException
from aana.utils.streamer import async_streamer_adapter
Expand Down Expand Up @@ -88,6 +88,7 @@ async def apply_config(self, config: dict[str, Any]):
self.model_id, **self.model_kwargs
)

@exception_handler
async def chat_stream(
self, dialog: ImageChatDialog, sampling_params: SamplingParams | None = None
) -> AsyncGenerator[LLMOutput, None]:
Expand Down Expand Up @@ -153,6 +154,7 @@ async def chat_stream(
except Exception as e:
raise InferenceException(model_name=self.model_id) from e

@exception_handler
movchan74 marked this conversation as resolved.
Show resolved Hide resolved
async def chat(
self, dialog: ImageChatDialog, sampling_params: SamplingParams | None = None
) -> ChatOutput:
Expand All @@ -171,6 +173,7 @@ async def chat(

return ChatOutput(message=ChatMessage(content=text, role="assistant"))

@exception_handler
async def chat_batch(
self,
dialogs: list[ImageChatDialog],
Expand Down
3 changes: 2 additions & 1 deletion aana/deployments/pyannote_speaker_diarization_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SpeakerDiarizationSegment,
)
from aana.core.models.time import TimeInterval
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException
from aana.processors.speaker import combine_homogeneous_speaker_diarization_segments

Expand Down Expand Up @@ -116,6 +116,7 @@ async def __inference(

return speaker_segments

@exception_handler
async def diarize(
self, audio: Audio, params: PyannoteSpeakerDiarizationParams | None = None
) -> SpeakerDiarizationOutput:
Expand Down
3 changes: 2 additions & 1 deletion aana/deployments/sentence_transformer_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing_extensions import TypedDict

from aana.core.models.base import pydantic_protected_fields
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException
from aana.processors.batch import BatchProcessor

Expand Down Expand Up @@ -70,6 +70,7 @@ async def apply_config(self, config: dict[str, Any]):

self.model = SentenceTransformer(self.model_id)

@exception_handler
async def embed_batch(self, **kwargs) -> np.ndarray:
"""Embed the given sentences.

Expand Down
3 changes: 2 additions & 1 deletion aana/deployments/vad_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aana.core.models.base import pydantic_protected_fields
from aana.core.models.time import TimeInterval
from aana.core.models.vad import VadParams, VadSegment
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException
from aana.processors.vad import BinarizeVadScores, VoiceActivitySegmentation
from aana.utils.download import download_model
Expand Down Expand Up @@ -211,6 +211,7 @@ async def __inference(self, audio: Audio) -> list[dict]:

return vad_segments

@exception_handler
async def asr_preprocess_vad(
self, audio: Audio, params: VadParams | None = None
) -> VadOutput:
Expand Down
19 changes: 18 additions & 1 deletion aana/deployments/vllm_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from aana.core.models.image_chat import ImageChatDialog
from aana.core.models.sampling import SamplingParams
from aana.core.models.types import Dtype
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.deployments.base_text_generation_deployment import (
ChatOutput,
LLMBatchOutput,
Expand Down Expand Up @@ -72,6 +72,11 @@ class VLLMConfig(BaseModel):
class VLLMDeployment(BaseDeployment):
"""Deployment to serve large language models using vLLM."""

def __init__(self):
    """Set up the base deployment state and start without an engine."""
    super().__init__()
    # No engine exists yet; presumably one is assigned during configuration
    # (the assignment is not visible in this excerpt).
    self.engine = None

async def apply_config(self, config: dict[str, Any]):
"""Apply the configuration.

Expand Down Expand Up @@ -123,6 +128,13 @@ async def apply_config(self, config: dict[str, Any]):
self.tokenizer = self.engine.engine.tokenizer.tokenizer
self.model_config = await self.engine.get_model_config()

async def check_health(self):
    """Run the engine's own health check, then the base deployment check.

    The engine check is skipped while ``self.engine`` is falsy (for example
    still ``None``); the base class check always runs afterwards.
    """
    engine = self.engine
    if engine:
        # Any error raised by the engine's check propagates to the caller.
        await engine.check_health()

    await super().check_health()

def apply_chat_template(
self, dialog: ChatDialog | ImageChatDialog
) -> tuple[str | list[int], dict | None]:
Expand Down Expand Up @@ -192,6 +204,7 @@ def replace_image_type(messages: list[dict], images: list[Image]) -> list[dict]:
)
return prompt, mm_data

@exception_handler
async def generate_stream( # noqa: C901
self,
prompt: str | list[int],
Expand Down Expand Up @@ -274,6 +287,7 @@ async def generate_stream( # noqa: C901
except Exception as e:
raise InferenceException(model_name=self.model_id) from e

@exception_handler
async def generate(
self,
prompt: str | list[int],
Expand All @@ -297,6 +311,7 @@ async def generate(
generated_text += chunk["text"]
return LLMOutput(text=generated_text)

@exception_handler
async def generate_batch(
self,
prompts: list[str] | list[list[int]],
Expand Down Expand Up @@ -326,6 +341,7 @@ async def generate_batch(

return LLMBatchOutput(texts=texts)

@exception_handler
async def chat(
self,
dialog: ChatDialog | ImageChatDialog,
Expand All @@ -349,6 +365,7 @@ async def chat(
response_message = ChatMessage(content=response["text"], role="assistant")
return ChatOutput(message=response_message)

@exception_handler
async def chat_stream(
self,
dialog: ChatDialog | ImageChatDialog,
Expand Down
Loading
Loading