diff --git a/api/logs/combined_api.log b/api/logs/combined_api.log
index a1b0a5a..cfa8f04 100644
--- a/api/logs/combined_api.log
+++ b/api/logs/combined_api.log
@@ -1 +1,2 @@
2024-12-23 14:03:00.055 | INFO | __main__:main:220 - Starting Combined Video Generation Server
+2024-12-23 14:58:34.285 | INFO | __main__:main:278 - Starting combined video generation server on port 8000
diff --git a/api/mochi_serve.py b/api/mochi_serve.py
index 079ad89..bb4e63f 100644
--- a/api/mochi_serve.py
+++ b/api/mochi_serve.py
@@ -1,5 +1,30 @@
"""
-LitServe API implementation for Mochi video generation service.
+Mochi Video Generation Service API
+
+An API implementation for the Mochi video generation model built on LitServe.
+Provides text-to-video generation with request validation, Prometheus monitoring,
+and structured error handling.
+
+Core Features:
+ - Automated video generation from text prompts
+ - S3 integration for result storage
+ - Prometheus metrics collection
+ - Resource utilization tracking
+ - Configurable generation parameters
+ - GPU memory management
+
+Technical Specifications:
+ - Input Resolution: 64-1024 pixels (width/height)
+ - Frame Range: 1-1000 frames
+ - FPS Range: 1-120
+ - Supported Formats: MP4 output
+ - Storage: AWS S3 integration
+
+Performance Monitoring:
+ - Request latency tracking
+ - Memory usage monitoring
+ - GPU utilization metrics
+ - Success/failure rates
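+
+Example request body (illustrative; the exact serving route depends on the
+LitServer configuration in main()):
+    {
+        "prompt": "A calm ocean scene at sunset",
+        "negative_prompt": "blurry, worst quality",
+        "num_inference_steps": 50,
+        "guidance_scale": 4.5,
+        "height": 480,
+        "width": 848
+    }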
"""
import io
@@ -19,33 +44,59 @@
import litserve
import tempfile
-# Set the directory for multiprocess mode
os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc_dir"
-
-# Ensure the directory exists
if not os.path.exists("/tmp/prometheus_multiproc_dir"):
os.makedirs("/tmp/prometheus_multiproc_dir")
-
-# Use a multiprocess registry
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
+class PrometheusLogger(litserve.Logger):
+ """
+    Prometheus metrics collector for Mochi service monitoring.
+    Tracks per-operation timing with multi-process support for production
+    deployments.
+
+ Metrics:
+ request_processing_seconds:
+ - Type: Histogram
+ - Labels: function_name
+ - Description: Processing time per operation
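+
+    Example query (illustrative PromQL for the 95th-percentile latency per function):
+        histogram_quantile(
+            0.95,
+            sum(rate(request_processing_seconds_bucket[5m])) by (le, function_name)
+        )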
+ """
-class PrometheusLogger(litserve.Logger):
def __init__(self):
super().__init__()
self.function_duration = Histogram("request_processing_seconds", "Time spent processing request", ["function_name"], registry=registry)
- def process(self, key, value):
- print("processing", key, value)
+ def process(self, key: str, value: float) -> None:
+ """
+ Record a metric observation with operation-specific labeling.
+
+ Args:
+ key: Operation identifier for metric labeling
+ value: Duration measurement in seconds
+ """
self.function_duration.labels(function_name=key).observe(value)
class VideoGenerationRequest(BaseModel):
"""
- Model representing a video generation request.
+ Validated request model for video generation parameters.
+
+ Enforces constraints and provides default values for all generation parameters.
+ Ensures request validity before resource allocation.
+
+ Attributes:
+ prompt: Primary generation directive
+ negative_prompt: Elements to avoid in generation
+ num_inference_steps: Generation quality control
+ guidance_scale: Prompt adherence strength
+ height: Output video height
+ width: Output video width
+ num_frames: Total frames to generate
+ fps: Playback frame rate
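+
+    Example (illustrative; only `prompt` is required, the remaining fields fall
+    back to the defaults defined below):
+        >>> req = VideoGenerationRequest(prompt="A calm ocean scene at sunset")
+        >>> req.num_inference_steps
+        50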
"""
+
prompt: str = Field(..., description="Text description of the video to generate")
negative_prompt: Optional[str] = Field("", description="Text description of what to avoid")
num_inference_steps: int = Field(50, ge=1, le=1000, description="Number of inference steps")
@@ -57,192 +108,236 @@ class VideoGenerationRequest(BaseModel):
class MochiVideoAPI(LitAPI):
"""
- API for Mochi video generation using LitServer.
+ Production-ready API implementation for Mochi video generation service.
+
+ Provides a robust, scalable interface for video generation with comprehensive
+ error handling, resource management, and performance monitoring.
+
+ Features:
+ - Request validation and normalization
+ - Batched processing support
+ - Automatic resource cleanup
+ - Detailed error reporting
+ - Performance metrics collection
"""
def setup(self, device: str) -> None:
- """Initialize the Mochi video generation model."""
- try:
- self.settings = MochiSettings(
- model_name="Mini-Mochi",
- enable_vae_tiling=True,
- enable_attention_slicing=True,
- device=device
- )
-
- logger.info("Initializing Mochi inference engine")
- self.engine = MochiInference(self.settings)
-
- # Create output directory
- self.output_dir = Path("outputs")
- self.output_dir.mkdir(parents=True, exist_ok=True)
-
- logger.info("Setup completed successfully")
- except Exception as e:
- logger.error(f"Error during setup: {e}")
- raise
+ """
+ Initialize the Mochi video generation infrastructure.
+
+ Performs model loading, resource allocation, and directory setup for
+ production deployment.
+
+ Args:
+ device: Target compute device for model deployment
+
+ Raises:
+ RuntimeError: On initialization failure
+ """
+ self.settings = MochiSettings(
+ model_name="Mini-Mochi",
+ enable_vae_tiling=True,
+ enable_attention_slicing=True,
+ device=device
+ )
+
+ logger.info("Initializing Mochi inference engine")
+ self.engine = MochiInference(self.settings)
+
+ self.output_dir = Path("outputs")
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ logger.info("Setup completed successfully")
def decode_request(self, request: Union[Dict[str, Any], List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
- """Decode and validate the incoming request."""
- try:
- # Ensure the request is a list of dictionaries
- if not isinstance(request, list):
- request = [request]
-
- # Validate each request in the list
- validated_requests = [VideoGenerationRequest(**req).model_dump() for req in request]
- return validated_requests
- except Exception as e:
- logger.error(f"Error in decode_request: {e}")
- raise
+ """
+ Validate and normalize incoming generation requests.
+
+ Args:
+ request: Raw request data, single or batched
+
+ Returns:
+ List of validated request parameters
+
+ Raises:
+ ValidationError: For malformed requests
+ """
+ if not isinstance(request, list):
+ request = [request]
+
+ validated_requests = [VideoGenerationRequest(**req).model_dump() for req in request]
+ return validated_requests
def batch(self, inputs: Union[Dict[str, Any], List[Dict[str, Any]]]) -> Dict[str, List[Any]]:
"""
Prepare inputs for batch processing.
+
+ Args:
+ inputs: Single or multiple generation requests
+
+ Returns:
+ Batched parameters ready for processing
+
+ Raises:
+ ValueError: For invalid batch composition
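+
+        Example (illustrative):
+            batch([{"prompt": "waves"}, {"prompt": "forest"}])
+            -> {"prompt": ["waves", "forest"], "negative_prompt": ["", ""], ...}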
"""
- try:
- # Convert single input to list format
- if not isinstance(inputs, list):
- inputs = [inputs]
-
- # Initialize with default values
- defaults = VideoGenerationRequest().model_dump()
-
- batched = {
- "prompt": [],
- "negative_prompt": [],
- "num_inference_steps": [],
- "guidance_scale": [],
- "height": [],
- "width": [],
- "num_frames": [],
- "fps": []
- }
-
- # Fill batched dictionary
- for input_item in inputs:
- for key in batched.keys():
- value = input_item.get(key, defaults.get(key))
- batched[key].append(value)
-
- return batched
- except Exception as e:
- logger.error(f"Error in batch processing: {e}")
- raise
+ if not isinstance(inputs, list):
+ inputs = [inputs]
+
+        # Build per-field defaults from the model definition; instantiating
+        # VideoGenerationRequest() directly would fail validation because
+        # `prompt` is a required field.
+        defaults = {
+            name: field.default
+            for name, field in VideoGenerationRequest.model_fields.items()
+        }
+
+ batched = {
+ "prompt": [],
+ "negative_prompt": [],
+ "num_inference_steps": [],
+ "guidance_scale": [],
+ "height": [],
+ "width": [],
+ "num_frames": [],
+ "fps": []
+ }
+
+ for input_item in inputs:
+ for key in batched.keys():
+ value = input_item.get(key, defaults.get(key))
+ batched[key].append(value)
+
+ return batched
def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Process inputs and generate videos."""
+ """
+ Execute video generation for validated requests.
+
+ Handles the complete generation pipeline including:
+ - Parameter validation
+ - Resource allocation
+ - Video generation
+ - S3 upload
+ - Performance monitoring
+ - Resource cleanup
+
+ Args:
+ inputs: List of validated generation parameters
+
+ Returns:
+ List of generation results or error details
+
+        Note:
+            Per-request failures are caught and returned as error entries in the
+            result list rather than raised, so one bad request does not fail the batch.
+        """
results = []
- try:
- for request in inputs:
- start_time = time.time()
- try:
- # Validate and parse the request
- generation_request = VideoGenerationRequest(**request)
+ for request in inputs:
+ start_time = time.time()
+ try:
+ generation_request = VideoGenerationRequest(**request)
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ temp_video_path = os.path.join(temp_dir, f"mochi_{int(time.time())}.mp4")
- # Create temporary file path
- with tempfile.TemporaryDirectory() as temp_dir:
- temp_video_path = os.path.join(temp_dir, f"mochi_{int(time.time())}.mp4")
-
- # Prepare generation parameters
- generation_params = generation_request.dict()
- generation_params["output_path"] = temp_video_path
-
- # Generate video
- logger.info(f"Starting generation for prompt: {generation_params['prompt']}")
- self.engine.generate(**generation_params)
-
- end_time = time.time()
- self.log("inference_time", end_time - start_time)
-
- # Get memory usage
- allocated, peak = self.engine.get_memory_usage()
-
- # Upload to S3
- with open(temp_video_path, "rb") as video_file:
- s3_response = mp4_to_s3_json(video_file, f"mochi_{int(time.time())}.mp4")
-
- result = {
- "status": "success",
- "video_id": s3_response["video_id"],
- "video_url": s3_response["url"],
- "prompt": generation_params["prompt"],
- "generation_params": generation_params,
- "time_taken": end_time - start_time,
- "memory_usage": {
- "allocated_gb": round(allocated, 2),
- "peak_gb": round(peak, 2)
- }
- }
- results.append(result)
-
- logger.info(f"Generation completed for prompt: {generation_params['prompt']}")
+                    generation_params = generation_request.model_dump()
+ generation_params["output_path"] = temp_video_path
- except Exception as e:
- logger.error(f"Error in generation for request: {e}")
- error_result = {
- "status": "error",
- "error": str(e)
+ logger.info(f"Starting generation for prompt: {generation_params['prompt']}")
+ self.engine.generate(**generation_params)
+
+ end_time = time.time()
+ self.log("inference_time", end_time - start_time)
+
+ allocated, peak = self.engine.get_memory_usage()
+
+ with open(temp_video_path, "rb") as video_file:
+ s3_response = mp4_to_s3_json(video_file, f"mochi_{int(time.time())}.mp4")
+
+ result = {
+ "status": "success",
+ "video_id": s3_response["video_id"],
+ "video_url": s3_response["url"],
+ "prompt": generation_params["prompt"],
+ "generation_params": generation_params,
+ "time_taken": end_time - start_time,
+ "memory_usage": {
+ "allocated_gb": round(allocated, 2),
+ "peak_gb": round(peak, 2)
+ }
}
- results.append(error_result)
- finally:
- self.engine.clear_memory()
+ results.append(result)
- except Exception as e:
- logger.error(f"Error in predict method: {e}")
- results.append({
- "status": "error",
- "error": str(e)
- })
-
+ logger.info(f"Generation completed for prompt: {generation_params['prompt']}")
+
+ except Exception as e:
+ logger.error(f"Error in generation for request: {e}")
+ error_result = {
+ "status": "error",
+ "error": str(e)
+ }
+ results.append(error_result)
+ finally:
+ self.engine.clear_memory()
+
return results if results else [{"status": "error", "error": "No results generated"}]
def unbatch(self, outputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Convert batched outputs back to individual results."""
+ """
+ Convert batched outputs to individual results.
+
+ Args:
+ outputs: List of generation results
+
+ Returns:
+ Unbatched list of results
+ """
return outputs
def encode_response(self, output: Union[Dict[str, Any], List[Any]]) -> Dict[str, Any]:
- """Encode the output for response."""
- try:
- # If output is a list, take the first item
- if isinstance(output, list):
- output = output[0] if output else {"status": "error", "error": "No output generated"}
-
- # Handle error cases
- if output.get("status") == "error":
- return {
- "status": "error",
- "error": output.get("error", "Unknown error"),
- "item_index": output.get("item_index")
- }
-
- # Return the successful response directly
- return {
- "status": "success",
- "video_id": output.get("video_id"),
- "video_url": output.get("video_url"),
- "generation_info": {
- "prompt": output.get("prompt"),
- "parameters": output.get("generation_params", {})
- },
- "performance": {
- "time_taken": round(output.get("time_taken", 0), 2),
- "memory_usage": output.get("memory_usage", {})
- }
- }
-
- except Exception as e:
- logger.error(f"Error in encode_response: {e}")
+ """
+ Format generation results for API response.
+
+ Args:
+ output: Raw generation results or error information
+
+ Returns:
+ Formatted API response with standardized structure
+
+ Note:
+ Handles both success and error cases with consistent formatting
+ """
+ if isinstance(output, list):
+ output = output[0] if output else {"status": "error", "error": "No output generated"}
+
+ if output.get("status") == "error":
return {
"status": "error",
- "error": str(e)
+ "error": output.get("error", "Unknown error"),
+ "item_index": output.get("item_index")
+ }
+
+ return {
+ "status": "success",
+ "video_id": output.get("video_id"),
+ "video_url": output.get("video_url"),
+ "generation_info": {
+ "prompt": output.get("prompt"),
+ "parameters": output.get("generation_params", {})
+ },
+ "performance": {
+ "time_taken": round(output.get("time_taken", 0), 2),
+ "memory_usage": output.get("memory_usage", {})
}
+ }
-if __name__ == "__main__":
- import sys
+def main():
+ """
+ Initialize and launch the Mochi video generation service.
+
+ Sets up the complete service infrastructure including:
+ - Prometheus metrics collection
+ - Structured logging
+ - API server configuration
+ - Error handling
+ """
prometheus_logger = PrometheusLogger()
prometheus_logger.mount(path="/api/v1/metrics", app=make_asgi_app(registry=registry))
- # Configure logging
+
logger.remove()
logger.add(
sys.stdout,
@@ -255,7 +350,6 @@ def encode_response(self, output: Union[Dict[str, Any], List[Any]]) -> Dict[str,
retention="1 week",
level="DEBUG"
)
-
try:
api = MochiVideoAPI()
@@ -268,10 +362,12 @@ def encode_response(self, output: Union[Dict[str, Any], List[Any]]) -> Dict[str,
track_requests=True,
loggers=prometheus_logger,
generate_client_file=False
-
)
logger.info("Starting server on port 8000")
server.run(port=8000)
except Exception as e:
logger.error(f"Server failed to start: {e}")
- sys.exit(1)
\ No newline at end of file
+ sys.exit(1)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/api/serve.py b/api/serve.py
index 9239144..1e7d5d2 100644
--- a/api/serve.py
+++ b/api/serve.py
@@ -1,31 +1,44 @@
"""
-Combined API router for LTX and Mochi video generation services.
+serve.py
-This module provides a unified endpoint that can handle requests for both
-LTX and Mochi video generation models. Clients specify which model to use
-via the 'model_name' parameter in their requests.
+A unified API for both LTX and Mochi video generation using batch-based parallel processing.
+All requests (single or multi) are sent as a batch, and we process them concurrently.
Usage:
- POST /predict
+    1. This file lives at `api/serve.py`.
+    2. Run from the repo root: python -m api.serve
+ 3. POST requests to http://localhost:8000/api/v1/inference
+
+Expected request format:
+{
+ "batch": [
{
- "model_name": "ltx", # or "mochi"
- "prompt": "your video prompt here",
- ...other model-specific parameters...
- }
+ "model_name": "mochi",
+ "prompt": "A calm ocean scene at sunset",
+ "negative_prompt": "blurry, worst quality",
+ "num_inference_steps": 50,
+ "guidance_scale": 4.5,
+ "height": 480,
+ "width": 848
+ ...
+ },
+ ...
+ ]
+}
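+
+Expected response format (illustrative; see encode_response below):
+{
+    "batch_results": [
+        {
+            "status": "success",
+            "video_id": "...",
+            "video_url": "...",
+            "model_name": "mochi",
+            "item_index": 0,
+            ...
+        }
+    ]
+}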
"""
import sys
import os
-from typing import Dict, Any, List, Union
-from pydantic import BaseModel, Field
-from loguru import logger
import torch
import litserve as ls
+import json
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Dict, List
+from loguru import logger
from prometheus_client import CollectorRegistry, Histogram, make_asgi_app, multiprocess
-from ltx_serve import LTXVideoAPI
-from mochi_serve import MochiVideoAPI
-
-# Setup Prometheus multiprocess mode
+from api.ltx_serve import LTXVideoAPI
+from api.mochi_serve import MochiVideoAPI
+from configs.combined_settings import CombinedBatchRequest, CombinedItemRequest
os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc"
if not os.path.exists("/tmp/prometheus_multiproc"):
os.makedirs("/tmp/prometheus_multiproc")
@@ -33,9 +46,13 @@
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
+
class PrometheusLogger(ls.Logger):
- """Custom logger for tracking combined API metrics."""
-
+ """
+ Custom logger for Prometheus metrics.
+ Tracks request duration for each (model_name, function_name) pair.
+ """
+
def __init__(self):
super().__init__()
self.function_duration = Histogram(
@@ -46,167 +63,198 @@ def __init__(self):
)
def process(self, key: str, value: float) -> None:
- """Record metric observations."""
- model_name, func_name = key.split(":", 1) if ":" in key else ("unknown", key)
+ """
+ Record metric observations with labels for both model_name and function_name.
+ `key` is expected to have the format "model_name:function_name".
+ """
+ if ":" in key:
+ model_name, func_name = key.split(":", 1)
+ else:
+ model_name, func_name = "unknown", key
+
self.function_duration.labels(
model_name=model_name,
function_name=func_name
).observe(value)
-class CombinedRequest(BaseModel):
- """Request model for the combined API endpoint."""
-
- model_name: str = Field(
- ...,
- description="Model to use for video generation ('ltx' or 'mochi')"
- )
-
- class Config:
- extra = "allow" # Allow additional fields for model-specific parameters
class CombinedVideoAPI(ls.LitAPI):
- """Combined API for serving both LTX and Mochi video generation models."""
+ """
+ Combined Video Generation API for both LTX and Mochi models.
+ This API handles requests in batch form, even for single items.
+
+ Steps:
+ 1) setup(device): Initialize LTX and Mochi sub-APIs on the specified device (CPU, GPU).
+ 2) decode_request(request): Parse the request body using Pydantic `CombinedBatchRequest`.
+ 3) predict(inputs): Parallel process each item in the batch.
+ 4) encode_response(outputs): Format the final JSON response.
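+
+    Example call (illustrative, using the route and port configured in main()):
+        curl -X POST http://localhost:8000/api/v1/inference \
+             -H "Content-Type: application/json" \
+             -d '{"batch": [{"model_name": "ltx", "prompt": "Golden autumn leaves swirling"}]}'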
+ """
def setup(self, device: str) -> None:
- """Initialize both video generation models.
-
- Args:
- device: Target device for model execution
"""
- logger.info(f"Setting up combined video API on device: {device}")
-
- # Initialize both APIs
+ Called once at server startup.
+ Initializes both the LTX and Mochi APIs on the same device.
+ """
+ logger.info(f"Initializing CombinedVideoAPI on device={device}")
self.ltx_api = LTXVideoAPI()
self.mochi_api = MochiVideoAPI()
-
- # Setup each model
+
self.ltx_api.setup(device=device)
self.mochi_api.setup(device=device)
-
- # Register models for routing
+
self.model_apis = {
"ltx": self.ltx_api,
"mochi": self.mochi_api
}
-
- logger.info("Successfully initialized all models")
-
- def decode_request(
- self,
- request: Union[Dict[str, Any], List[Dict[str, Any]]]
- ) -> Dict[str, Any]:
- """Validate request and determine target model.
-
+
+ logger.info("All sub-APIs (LTX, Mochi) successfully set up")
+
+ def decode_request(self, request: Any) -> Dict[str, List[Dict[str, Any]]]:
+ """
+ Interprets the raw request body as a batch, then validates it.
+ We unify single vs. multiple requests by requiring a `batch` array.
+
Args:
- request: Raw request data
-
+ request: The raw request data (usually a dict from the body).
+
Returns:
- Decoded request with model selection
-
+ A dictionary with key 'items' containing a list of validated dicts.
+
Raises:
- ValueError: If model_name is invalid
+ ValidationError if the request doesn't match CombinedBatchRequest schema.
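+
+        Example (illustrative): a bare JSON array body such as
+            [{"model_name": "ltx", "prompt": "Golden autumn leaves swirling"}]
+        is wrapped into {"batch": [...]} before validation.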
"""
+ # If user directly posted an array, wrap it to match the expected schema
if isinstance(request, list):
- request = request[0] # Handle single requests for now
-
- validated = CombinedRequest(**request).dict()
- model_name = validated.pop("model_name").lower()
-
- if model_name not in self.model_apis:
- raise ValueError(
- f"Invalid model_name: {model_name}. "
- f"Available models: {list(self.model_apis.keys())}"
- )
-
- return {
- "model_name": model_name,
- "request_data": validated
- }
+ request = {"batch": request}
- def predict(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
- """Route request to appropriate model and generate video.
-
- Args:
- inputs: Decoded request data
-
- Returns:
- Generation results from selected model
+ # Validate using CombinedBatchRequest
+ validated_batch = CombinedBatchRequest(**request)
+
+ # Convert each CombinedItemRequest into a dict for usage in predict
+        items = [item.model_dump() for item in validated_batch.batch]
+ return {"items": items}
+
+ def predict(self, inputs: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
"""
- model_name = inputs["model_name"]
- request_data = inputs["request_data"]
- model_api = self.model_apis[model_name]
-
- try:
- # Process request through selected model
- decoded = model_api.decode_request(request_data)
- predictions = model_api.predict(decoded)
- result = predictions[0] if predictions else {
- "status": "error",
- "error": "No result returned"
- }
-
- return {
- "model_name": model_name,
- "result": result
- }
-
- except Exception as e:
- import traceback
- logger.error(f"Error in {model_name} prediction: {str(e)}")
- return {
- "model_name": model_name,
- "status": "error",
- "error": str(e),
- "traceback": traceback.format_exc()
- }
- finally:
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
+ Execute parallel inference for all items in the 'items' list.
- def encode_response(self, output: Dict[str, Any]) -> Dict[str, Any]:
- """Encode final response using model-specific encoder.
-
Args:
- output: Raw model output
-
+ inputs: Dictionary with key 'items' -> list of items.
+ Each item is a dict with fields like 'model_name', 'prompt', etc.
+
Returns:
- Encoded response ready for client
+ Dictionary with 'batch_results': a list of output dicts,
+ each containing status, video_id, video_url, etc.
"""
- model_name = output.get("model_name")
- if model_name and model_name in self.model_apis:
- result = output.get("result", {})
-
- if result.get("status") == "error":
+ items = inputs["items"]
+ logger.info(f"Processing batch of {len(items)} request(s) in parallel")
+
+ # We'll define a helper function for one item
+ def process_single(item: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Takes a single request dict, delegates to the correct sub-API (LTX or Mochi).
+ Returns the predicted result (video URL, etc.).
+ """
+ model_name = item.get("model_name", "").lower()
+ if model_name not in self.model_apis:
return {
"status": "error",
- "error": result.get("error", "Unknown error"),
- "traceback": result.get("traceback")
+ "error": f"Invalid model_name: {model_name}"
}
-
- encoded = self.model_apis[model_name].encode_response(result)
- encoded["model_name"] = model_name
- return encoded
- else:
+
+ sub_api = self.model_apis[model_name]
+
+ # Sub-API workflow: decode -> predict -> single result
+            # Note: the sub-APIs' decode_request() and predict() operate on lists,
+            # so we unwrap the first element of the returned results below.
+ try:
+ # Prepare sub-request in their expected format
+ sub_decoded = sub_api.decode_request(item)
+ sub_pred = sub_api.predict(sub_decoded)
+ return sub_pred[0] if sub_pred else {
+ "status": "error",
+ "error": "No result returned from sub-API."
+ }
+ except Exception as e:
+ logger.error(f"[{model_name}] sub-api error: {e}")
+ return {"status": "error", "error": str(e), "model_name": model_name}
+
+        # Use a ThreadPoolExecutor for concurrency: the nested `process_single`
+        # closure and the GPU-backed sub-APIs it captures cannot be pickled for a
+        # worker process, so threads are used to keep the models in-process.
+        results = []
+        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
+ future_to_idx = {}
+ for idx, item in enumerate(items):
+ future = executor.submit(process_single, item)
+ future_to_idx[future] = idx
+
+ for f in as_completed(future_to_idx):
+ idx = future_to_idx[f]
+ try:
+ out = f.result()
+ out["item_index"] = idx
+ if "model_name" not in out:
+ out["model_name"] = items[idx].get("model_name", "unknown")
+ results.append(out)
+ except Exception as e:
+ # If something catastrophic happened in process_single
+ results.append({"status": "error", "error": str(e), "item_index": idx})
+
+ # Sort results by item_index so response order matches input order
+ results.sort(key=lambda x: x["item_index"])
+ return {"batch_results": results}
+
+ def encode_response(self, outputs: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Convert the raw dictionary from `predict` into a final response.
+ We unify single vs. multiple items:
+ - The client always receives "batch_results"
+ (with 1 result if originally a single item).
+
+    Each sub-API provides its own encode_response() to standardize the final JSON,
+    so we call it here to keep the per-item format consistent.
+
+ Returns:
+ The final JSON-serializable dict.
+ """
+ if "batch_results" not in outputs:
return {
- "status": "error",
- "error": output.get("error", "Unknown routing error"),
- "traceback": output.get("traceback")
+ "status": "error",
+ "error": "No batch_results field found in predict output"
}
+ for item in outputs["batch_results"]:
+ if item.get("status") == "error":
+ continue
+
+ model_name = item.get("model_name", "").lower()
+ if model_name in self.model_apis:
+ sub_encoded = self.model_apis[model_name].encode_response(item)
+ item.update(sub_encoded)
+
+ return outputs
+
+
def main():
- """Initialize and start the combined video generation server."""
- # Setup Prometheus metrics
+ """
+    Main entry point for the combined server, exposing /api/v1/inference on port 8000.
+ This version logs metrics to Prometheus and logs to console + file.
+ """
+ from litserve import LitServer
+
+ # PROMETHEUS LOGGER
prometheus_logger = PrometheusLogger()
prometheus_logger.mount(
path="/metrics",
app=make_asgi_app(registry=registry)
)
- # Configure logging
- logger.remove()
+ # LOGGING
+ logger.remove() # Remove default handler
logger.add(
sys.stdout,
- format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function} - {message}",
+ format="{time:YYYY-MM-DD HH:mm:ss} "
+ "| {level: <8} "
+ "| {name}:{function} - "
+ "{message}",
level="INFO"
)
logger.add(
@@ -216,19 +264,20 @@ def main():
level="DEBUG"
)
- # Start server
- logger.info("Starting Combined Video Generation Server")
api = CombinedVideoAPI()
- server = ls.LitServer(
+ server = LitServer(
api,
- api_path="/predict",
- accelerator="auto",
- devices="auto",
- max_batch_size=1,
+ api_path="/api/v1/inference",
+ accelerator="auto",
+ devices="auto",
+ max_batch_size=4,
track_requests=True,
loggers=[prometheus_logger]
)
+
+ logger.info("Starting combined video generation server on port 8000")
server.run(port=8000)
+
if __name__ == "__main__":
main()
diff --git a/configs/__pycache__/combined_settings.cpython-310.pyc b/configs/__pycache__/combined_settings.cpython-310.pyc
new file mode 100644
index 0000000..659a291
Binary files /dev/null and b/configs/__pycache__/combined_settings.cpython-310.pyc differ
diff --git a/configs/combined_settings.py b/configs/combined_settings.py
new file mode 100644
index 0000000..bd9d7d9
--- /dev/null
+++ b/configs/combined_settings.py
@@ -0,0 +1,110 @@
+"""
+combined_settings.py
+
+Central Pydantic models and optional unified config for combining Mochi and LTX requests/settings.
+
+1) CombinedItemRequest:
+ - Defines the schema for a single text-to-video request, including model_name (mochi/ltx)
+ and common fields like prompt, negative_prompt, resolution, etc.
+
+2) CombinedBatchRequest:
+ - Defines a list of CombinedItemRequest items, handling batch-mode requests.
+
+3) CombinedConfig (optional):
+ - Demonstrates how you might wrap both mochi_settings and ltx_settings into a single config
+ so everything can be accessed from a unified settings object if needed.
+
+Usage Example (Batch):
+    POST /api/v1/inference
+ {
+ "batch": [
+ {
+ "model_name": "mochi",
+ "prompt": "A calm ocean scene at sunset",
+ "negative_prompt": "blurry, worst quality",
+ "num_inference_steps": 50,
+ "guidance_scale": 4.5,
+ "height": 480,
+ "width": 848
+ },
+ {
+ "model_name": "ltx",
+ "prompt": "Golden autumn leaves swirling",
+ "num_inference_steps": 40,
+ "guidance_scale": 3.0,
+ "height": 480,
+ "width": 704
+ }
+ ]
+ }
+"""
+
+from pydantic import BaseModel, Field
+from typing import List, Optional
+
+# If you want to embed or reference them here:
+from .mochi_settings import MochiSettings
+from .ltx_settings import LTXVideoSettings
+from pydantic_settings import BaseSettings
+
+
+class CombinedItemRequest(BaseModel):
+ """
+ A single request object for either Mochi or LTX.
+
+ Fields:
+ model_name (str): Which model to use: 'mochi' or 'ltx'.
+ prompt (str): Main prompt describing the video content.
+ negative_prompt (Optional[str]): Text describing what to avoid.
+ num_inference_steps (Optional[int]): Override for inference steps.
+ guidance_scale (Optional[float]): Classifier-free guidance scale.
+ height (Optional[int]): Video height in pixels.
+ width (Optional[int]): Video width in pixels.
+ (Add additional fields as needed for your models.)
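+
+    Example (illustrative):
+        CombinedItemRequest(model_name="ltx", prompt="Golden autumn leaves swirling")
+        # Unspecified fields fall back to the defaults below
+        # (40 steps, guidance 3.0, 480x704).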
+ """
+ model_name: str = Field(..., description="Model to use: 'ltx' or 'mochi'.")
+ prompt: str = Field(..., description="Prompt describing the video content.")
+ negative_prompt: Optional[str] = Field(None, description="Things to avoid in generation.")
+ num_inference_steps: Optional[int] = Field(40, description="Number of denoising steps.")
+ guidance_scale: Optional[float] = Field(3.0, description="Guidance scale for generation.")
+ height: Optional[int] = Field(480, description="Video height in pixels.")
+ width: Optional[int] = Field(704, description="Video width in pixels.")
+ # Add any more fields your sub-models need, e.g. fps, frames, etc.
+
+
+class CombinedBatchRequest(BaseModel):
+ """
+ A batched request containing multiple CombinedItemRequest items.
+
+ Usage:
+ {
+ "batch": [
+ { "model_name": "mochi", "prompt": "...", ... },
+ { "model_name": "ltx", "prompt": "...", ... }
+ ]
+ }
+ """
+ batch: List[CombinedItemRequest] = Field(
+ ..., description="List of multiple CombinedItemRequest items to process in parallel."
+ )
+
+
+class CombinedConfig(BaseSettings):
+ """
+ Optional: A unified config that embeds or references your Mochi/LTX model settings.
+
+ This can be used if you want to store and manipulate both sets of settings in one place.
+ For example, you might define environment variables to override mochi or ltx defaults.
+
+ Usage:
+ from configs.combined_settings import CombinedConfig
+ combined_config = CombinedConfig()
+ # Access mochi or ltx settings: combined_config.mochi_config, combined_config.ltx_config
+ """
+ mochi_config: MochiSettings = MochiSettings()
+ ltx_config: LTXVideoSettings = LTXVideoSettings()
+
+ class Config:
+ env_prefix = "COMBINED_"
+ validate_assignment = True
+ arbitrary_types_allowed = True