Skip to content

Commit

Permalink
Enhance LTX and Mochi video generation APIs with detailed documentati…
Browse files Browse the repository at this point in the history
…on and combined routing

- Updated ltx_serve.py and serve.py to improve API structure and documentation.
- Added comprehensive docstrings for key classes and methods, enhancing clarity on functionality and usage.
- Implemented a combined API router to handle requests for both LTX and Mochi models, allowing clients to specify the model via a single endpoint.
- Improved Prometheus metrics logging for better performance monitoring.
- Updated logging configuration to support combined API metrics.

These changes aim to streamline the API usage, improve maintainability, and enhance user understanding of the video generation services.
  • Loading branch information
VikramxD committed Dec 23, 2024
1 parent 2d5ac3a commit 349eed7
Show file tree
Hide file tree
Showing 12 changed files with 213 additions and 134 deletions.
Binary file added api/__pycache__/ltx_serve.cpython-310.pyc
Binary file not shown.
Binary file added api/__pycache__/mochi_serve.cpython-310.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions api/logs/combined_api.log
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
2024-12-23 14:03:00.055 | INFO | __main__:main:220 - Starting Combined Video Generation Server
119 changes: 110 additions & 9 deletions api/ltx_serve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
"""
LitServe API implementation for LTX video generation service.
This module provides a FastAPI-based service for generating videos using the LTX model.
It handles request validation, video generation, S3 upload, and monitoring through Prometheus.
Key Components:
- PrometheusLogger: Custom metrics logging
- VideoGenerationRequest: Request validation model
- LTXVideoAPI: Main API implementation
"""

import os
Expand Down Expand Up @@ -33,7 +41,14 @@
multiprocess.MultiProcessCollector(registry)

class PrometheusLogger(Logger):
"""Custom logger for Prometheus metrics."""
"""Custom logger for Prometheus metrics.
Implements metric collection for video generation request processing times
using Prometheus Histograms. Metrics are stored in a multi-process compatible registry.
Attributes:
function_duration (Histogram): Prometheus histogram for tracking processing times
"""

def __init__(self):
super().__init__()
Expand All @@ -45,11 +60,31 @@ def __init__(self):
)

def process(self, key: str, value: float) -> None:
"""Process and record metric."""
"""Process and record a metric value.
Args:
key (str): The name of the function or operation being measured
value (float): The duration or metric value to record
"""
self.function_duration.labels(function_name=key).observe(value)

class VideoGenerationRequest(BaseModel):
"""Model representing a video generation request."""
"""Model representing a video generation request.
Validates and normalizes input parameters for video generation.
Provides default values and constraints for all generation parameters.
Attributes:
prompt (str): Main text description for video generation
negative_prompt (Optional[str]): Text description of elements to avoid
num_inference_steps (int): Number of denoising steps (1-100)
guidance_scale (float): Controls adherence to prompt (1.0-20.0)
height (int): Video height in pixels (256-720, multiple of 32)
width (int): Video width in pixels (256-1280, multiple of 32)
num_frames (int): Number of frames to generate (1-257)
frame_rate (int): Output video frame rate (1-60)
seed (Optional[int]): Random seed for reproducibility
"""

prompt: str = Field(..., description="Text description of the video to generate")
negative_prompt: Optional[str] = Field(
Expand Down Expand Up @@ -97,10 +132,27 @@ class VideoGenerationRequest(BaseModel):
seed: Optional[int] = Field(None, description="Random seed for generation")

class LTXVideoAPI(LitAPI):
"""API for LTX video generation using LitServer."""
"""API for LTX video generation using LitServer.
Implements the core video generation workflow including model initialization,
request processing, video generation, and result handling.
Attributes:
settings (LTXVideoSettings): Configuration for video generation
engine (LTXInference): Video generation inference engine
"""

def setup(self, device: str) -> None:
"""Initialize the LTX video generation model."""
"""Initialize the LTX video generation model.
Sets up the video generation engine and loads models onto the specified device.
Args:
device (str): Target device for model execution ('cuda', 'cpu', etc.)
Raises:
Exception: If model initialization fails
"""
try:
logger.info(f"Initializing LTX video generation on device: {device}")

Expand All @@ -120,7 +172,19 @@ def decode_request(
self,
request: Union[Dict[str, Any], List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""Decode and validate the incoming request."""
"""Decode and validate the incoming request.
Converts raw request data into validated VideoGenerationRequest objects.
Args:
request: Single request dict or list of request dicts
Returns:
List of validated request dictionaries
Raises:
ValidationError: If request validation fails
"""
try:
# Ensure request is a list
if not isinstance(request, list):
Expand All @@ -141,7 +205,19 @@ def batch(
self,
inputs: Union[Dict[str, Any], List[Dict[str, Any]]]
) -> Dict[str, List[Any]]:
"""Prepare inputs for batch processing."""
"""Prepare inputs for batch processing.
Organizes single or multiple requests into a batched format for processing.
Args:
inputs: Single input dict or list of input dicts
Returns:
Dictionary with lists of batched parameters
Raises:
Exception: If batch preparation fails
"""
try:
# Convert single input to list
if not isinstance(inputs, list):
Expand Down Expand Up @@ -176,7 +252,23 @@ def batch(
raise

def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Process inputs and generate videos."""
"""Process inputs and generate videos.
Core video generation method that handles the complete pipeline:
- Parameter validation
- Video generation
- S3 upload
- Performance monitoring
Args:
inputs: List of validated generation requests
Returns:
List of generation results including video URLs and metadata
Raises:
Exception: If video generation or upload fails
"""
results = []

try:
Expand Down Expand Up @@ -317,7 +409,16 @@ def encode_response(
}

def main():
"""Main entry point for the API server."""
"""Main entry point for the API server.
Initializes and starts the Litser server with:
- Prometheus metrics endpoint
- Configured logging
- LTX video generation API
- Server settings for batching and acceleration
Exits with status code 1 if server initialization fails.
"""
# Initialize Prometheus logger
prometheus_logger = PrometheusLogger()
prometheus_logger.mount(
Expand Down
Loading

0 comments on commit 349eed7

Please sign in to comment.