diff --git a/api/logs/combined_api.log b/api/logs/combined_api.log index a1b0a5a..cfa8f04 100644 --- a/api/logs/combined_api.log +++ b/api/logs/combined_api.log @@ -1 +1,2 @@ 2024-12-23 14:03:00.055 | INFO | __main__:main:220 - Starting Combined Video Generation Server +2024-12-23 14:58:34.285 | INFO | __main__:main:278 - Starting combined video generation server on port 8000 diff --git a/api/mochi_serve.py b/api/mochi_serve.py index 079ad89..bb4e63f 100644 --- a/api/mochi_serve.py +++ b/api/mochi_serve.py @@ -1,5 +1,30 @@ """ -LitServe API implementation for Mochi video generation service. +Mochi Video Generation Service API + +A high-performance API implementation for the Mochi video generation model using LitServe. +Provides enterprise-grade video generation capabilities with comprehensive monitoring, +validation, and error handling. + +Core Features: + - Automated video generation from text prompts + - S3 integration for result storage + - Prometheus metrics collection + - Resource utilization tracking + - Configurable generation parameters + - GPU memory management + +Technical Specifications: + - Input Resolution: 64-1024 pixels (width/height) + - Frame Range: 1-1000 frames + - FPS Range: 1-120 + - Supported Formats: MP4 output + - Storage: AWS S3 integration + +Performance Monitoring: + - Request latency tracking + - Memory usage monitoring + - GPU utilization metrics + - Success/failure rates """ import io @@ -19,33 +44,59 @@ import litserve import tempfile -# Set the directory for multiprocess mode os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc_dir" - -# Ensure the directory exists if not os.path.exists("/tmp/prometheus_multiproc_dir"): os.makedirs("/tmp/prometheus_multiproc_dir") - -# Use a multiprocess registry registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) +class PrometheusLogger(litserve.Logger): + """ + Enterprise-grade Prometheus metrics collector for Mochi service monitoring. + Implements detailed performance tracking with multi-process support for production + deployments. Provides high-resolution timing metrics for all service operations. + + Metrics: + request_processing_seconds: + - Type: Histogram + - Labels: function_name + - Description: Processing time per operation + """ -class PrometheusLogger(litserve.Logger): def __init__(self): super().__init__() self.function_duration = Histogram("request_processing_seconds", "Time spent processing request", ["function_name"], registry=registry) - def process(self, key, value): - print("processing", key, value) + def process(self, key: str, value: float) -> None: + """ + Record a metric observation with operation-specific labeling. + + Args: + key: Operation identifier for metric labeling + value: Duration measurement in seconds + """ self.function_duration.labels(function_name=key).observe(value) class VideoGenerationRequest(BaseModel): """ - Model representing a video generation request. + Validated request model for video generation parameters. + + Enforces constraints and provides default values for all generation parameters. + Ensures request validity before resource allocation. 
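+
+    Example (illustrative; only prompt is required, the remaining fields fall
+    back to the defaults declared on the fields below):
+        VideoGenerationRequest(prompt="A calm ocean scene at sunset", fps=30)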
+ + Attributes: + prompt: Primary generation directive + negative_prompt: Elements to avoid in generation + num_inference_steps: Generation quality control + guidance_scale: Prompt adherence strength + height: Output video height + width: Output video width + num_frames: Total frames to generate + fps: Playback frame rate """ + prompt: str = Field(..., description="Text description of the video to generate") negative_prompt: Optional[str] = Field("", description="Text description of what to avoid") num_inference_steps: int = Field(50, ge=1, le=1000, description="Number of inference steps") @@ -57,192 +108,236 @@ class VideoGenerationRequest(BaseModel): class MochiVideoAPI(LitAPI): """ - API for Mochi video generation using LitServer. + Production-ready API implementation for Mochi video generation service. + + Provides a robust, scalable interface for video generation with comprehensive + error handling, resource management, and performance monitoring. + + Features: + - Request validation and normalization + - Batched processing support + - Automatic resource cleanup + - Detailed error reporting + - Performance metrics collection """ def setup(self, device: str) -> None: - """Initialize the Mochi video generation model.""" - try: - self.settings = MochiSettings( - model_name="Mini-Mochi", - enable_vae_tiling=True, - enable_attention_slicing=True, - device=device - ) - - logger.info("Initializing Mochi inference engine") - self.engine = MochiInference(self.settings) - - # Create output directory - self.output_dir = Path("outputs") - self.output_dir.mkdir(parents=True, exist_ok=True) - - logger.info("Setup completed successfully") - except Exception as e: - logger.error(f"Error during setup: {e}") - raise + """ + Initialize the Mochi video generation infrastructure. + + Performs model loading, resource allocation, and directory setup for + production deployment. + + Args: + device: Target compute device for model deployment + + Raises: + RuntimeError: On initialization failure + """ + self.settings = MochiSettings( + model_name="Mini-Mochi", + enable_vae_tiling=True, + enable_attention_slicing=True, + device=device + ) + + logger.info("Initializing Mochi inference engine") + self.engine = MochiInference(self.settings) + + self.output_dir = Path("outputs") + self.output_dir.mkdir(parents=True, exist_ok=True) + + logger.info("Setup completed successfully") def decode_request(self, request: Union[Dict[str, Any], List[Dict[str, Any]]]) -> List[Dict[str, Any]]: - """Decode and validate the incoming request.""" - try: - # Ensure the request is a list of dictionaries - if not isinstance(request, list): - request = [request] - - # Validate each request in the list - validated_requests = [VideoGenerationRequest(**req).model_dump() for req in request] - return validated_requests - except Exception as e: - logger.error(f"Error in decode_request: {e}") - raise + """ + Validate and normalize incoming generation requests. + + Args: + request: Raw request data, single or batched + + Returns: + List of validated request parameters + + Raises: + ValidationError: For malformed requests + """ + if not isinstance(request, list): + request = [request] + + validated_requests = [VideoGenerationRequest(**req).model_dump() for req in request] + return validated_requests def batch(self, inputs: Union[Dict[str, Any], List[Dict[str, Any]]]) -> Dict[str, List[Any]]: """ Prepare inputs for batch processing. 
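+
+        Example (illustrative): two incoming requests are transposed into
+        per-field lists, e.g. {"prompt": ["p1", "p2"], "fps": [30, 30], ...}.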
+ + Args: + inputs: Single or multiple generation requests + + Returns: + Batched parameters ready for processing + + Raises: + ValueError: For invalid batch composition """ - try: - # Convert single input to list format - if not isinstance(inputs, list): - inputs = [inputs] - - # Initialize with default values - defaults = VideoGenerationRequest().model_dump() - - batched = { - "prompt": [], - "negative_prompt": [], - "num_inference_steps": [], - "guidance_scale": [], - "height": [], - "width": [], - "num_frames": [], - "fps": [] - } - - # Fill batched dictionary - for input_item in inputs: - for key in batched.keys(): - value = input_item.get(key, defaults.get(key)) - batched[key].append(value) - - return batched - except Exception as e: - logger.error(f"Error in batch processing: {e}") - raise + if not isinstance(inputs, list): + inputs = [inputs] + + defaults = VideoGenerationRequest().model_dump() + + batched = { + "prompt": [], + "negative_prompt": [], + "num_inference_steps": [], + "guidance_scale": [], + "height": [], + "width": [], + "num_frames": [], + "fps": [] + } + + for input_item in inputs: + for key in batched.keys(): + value = input_item.get(key, defaults.get(key)) + batched[key].append(value) + + return batched def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Process inputs and generate videos.""" + """ + Execute video generation for validated requests. + + Handles the complete generation pipeline including: + - Parameter validation + - Resource allocation + - Video generation + - S3 upload + - Performance monitoring + - Resource cleanup + + Args: + inputs: List of validated generation parameters + + Returns: + List of generation results or error details + + Raises: + RuntimeError: On generation failure + """ results = [] - try: - for request in inputs: - start_time = time.time() - try: - # Validate and parse the request - generation_request = VideoGenerationRequest(**request) + for request in inputs: + start_time = time.time() + try: + generation_request = VideoGenerationRequest(**request) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_video_path = os.path.join(temp_dir, f"mochi_{int(time.time())}.mp4") - # Create temporary file path - with tempfile.TemporaryDirectory() as temp_dir: - temp_video_path = os.path.join(temp_dir, f"mochi_{int(time.time())}.mp4") - - # Prepare generation parameters - generation_params = generation_request.dict() - generation_params["output_path"] = temp_video_path - - # Generate video - logger.info(f"Starting generation for prompt: {generation_params['prompt']}") - self.engine.generate(**generation_params) - - end_time = time.time() - self.log("inference_time", end_time - start_time) - - # Get memory usage - allocated, peak = self.engine.get_memory_usage() - - # Upload to S3 - with open(temp_video_path, "rb") as video_file: - s3_response = mp4_to_s3_json(video_file, f"mochi_{int(time.time())}.mp4") - - result = { - "status": "success", - "video_id": s3_response["video_id"], - "video_url": s3_response["url"], - "prompt": generation_params["prompt"], - "generation_params": generation_params, - "time_taken": end_time - start_time, - "memory_usage": { - "allocated_gb": round(allocated, 2), - "peak_gb": round(peak, 2) - } - } - results.append(result) - - logger.info(f"Generation completed for prompt: {generation_params['prompt']}") + generation_params = generation_request.dict() + generation_params["output_path"] = temp_video_path - except Exception as e: - logger.error(f"Error in generation for request: 
{e}") - error_result = { - "status": "error", - "error": str(e) + logger.info(f"Starting generation for prompt: {generation_params['prompt']}") + self.engine.generate(**generation_params) + + end_time = time.time() + self.log("inference_time", end_time - start_time) + + allocated, peak = self.engine.get_memory_usage() + + with open(temp_video_path, "rb") as video_file: + s3_response = mp4_to_s3_json(video_file, f"mochi_{int(time.time())}.mp4") + + result = { + "status": "success", + "video_id": s3_response["video_id"], + "video_url": s3_response["url"], + "prompt": generation_params["prompt"], + "generation_params": generation_params, + "time_taken": end_time - start_time, + "memory_usage": { + "allocated_gb": round(allocated, 2), + "peak_gb": round(peak, 2) + } } - results.append(error_result) - finally: - self.engine.clear_memory() + results.append(result) - except Exception as e: - logger.error(f"Error in predict method: {e}") - results.append({ - "status": "error", - "error": str(e) - }) - + logger.info(f"Generation completed for prompt: {generation_params['prompt']}") + + except Exception as e: + logger.error(f"Error in generation for request: {e}") + error_result = { + "status": "error", + "error": str(e) + } + results.append(error_result) + finally: + self.engine.clear_memory() + return results if results else [{"status": "error", "error": "No results generated"}] def unbatch(self, outputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Convert batched outputs back to individual results.""" + """ + Convert batched outputs to individual results. + + Args: + outputs: List of generation results + + Returns: + Unbatched list of results + """ return outputs def encode_response(self, output: Union[Dict[str, Any], List[Any]]) -> Dict[str, Any]: - """Encode the output for response.""" - try: - # If output is a list, take the first item - if isinstance(output, list): - output = output[0] if output else {"status": "error", "error": "No output generated"} - - # Handle error cases - if output.get("status") == "error": - return { - "status": "error", - "error": output.get("error", "Unknown error"), - "item_index": output.get("item_index") - } - - # Return the successful response directly - return { - "status": "success", - "video_id": output.get("video_id"), - "video_url": output.get("video_url"), - "generation_info": { - "prompt": output.get("prompt"), - "parameters": output.get("generation_params", {}) - }, - "performance": { - "time_taken": round(output.get("time_taken", 0), 2), - "memory_usage": output.get("memory_usage", {}) - } - } - - except Exception as e: - logger.error(f"Error in encode_response: {e}") + """ + Format generation results for API response. 
+ + Args: + output: Raw generation results or error information + + Returns: + Formatted API response with standardized structure + + Note: + Handles both success and error cases with consistent formatting + """ + if isinstance(output, list): + output = output[0] if output else {"status": "error", "error": "No output generated"} + + if output.get("status") == "error": return { "status": "error", - "error": str(e) + "error": output.get("error", "Unknown error"), + "item_index": output.get("item_index") + } + + return { + "status": "success", + "video_id": output.get("video_id"), + "video_url": output.get("video_url"), + "generation_info": { + "prompt": output.get("prompt"), + "parameters": output.get("generation_params", {}) + }, + "performance": { + "time_taken": round(output.get("time_taken", 0), 2), + "memory_usage": output.get("memory_usage", {}) } + } -if __name__ == "__main__": - import sys +def main(): + """ + Initialize and launch the Mochi video generation service. + + Sets up the complete service infrastructure including: + - Prometheus metrics collection + - Structured logging + - API server configuration + - Error handling + """ prometheus_logger = PrometheusLogger() prometheus_logger.mount(path="/api/v1/metrics", app=make_asgi_app(registry=registry)) - # Configure logging + logger.remove() logger.add( sys.stdout, @@ -255,7 +350,6 @@ def encode_response(self, output: Union[Dict[str, Any], List[Any]]) -> Dict[str, retention="1 week", level="DEBUG" ) - try: api = MochiVideoAPI() @@ -268,10 +362,12 @@ def encode_response(self, output: Union[Dict[str, Any], List[Any]]) -> Dict[str, track_requests=True, loggers=prometheus_logger, generate_client_file=False - ) logger.info("Starting server on port 8000") server.run(port=8000) except Exception as e: logger.error(f"Server failed to start: {e}") - sys.exit(1) \ No newline at end of file + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/api/serve.py b/api/serve.py index 9239144..1e7d5d2 100644 --- a/api/serve.py +++ b/api/serve.py @@ -1,31 +1,44 @@ """ -Combined API router for LTX and Mochi video generation services. +combined_serve.py -This module provides a unified endpoint that can handle requests for both -LTX and Mochi video generation models. Clients specify which model to use -via the 'model_name' parameter in their requests. +A unified API for both LTX and Mochi video generation using batch-based parallel processing. +All requests (single or multi) are sent as a batch, and we process them concurrently. Usage: - POST /predict + 1. Place this file in your repo (e.g., in `api/combined_serve.py`). + 2. Run: python combined_serve.py + 3. POST requests to http://localhost:8000/api/v1/inference + +Expected request format: +{ + "batch": [ { - "model_name": "ltx", # or "mochi" - "prompt": "your video prompt here", - ...other model-specific parameters... - } + "model_name": "mochi", + "prompt": "A calm ocean scene at sunset", + "negative_prompt": "blurry, worst quality", + "num_inference_steps": 50, + "guidance_scale": 4.5, + "height": 480, + "width": 848 + ... + }, + ... 
+ ] +} """ import sys import os -from typing import Dict, Any, List, Union -from pydantic import BaseModel, Field -from loguru import logger import torch import litserve as ls +import json +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Any, Dict, List +from loguru import logger from prometheus_client import CollectorRegistry, Histogram, make_asgi_app, multiprocess -from ltx_serve import LTXVideoAPI -from mochi_serve import MochiVideoAPI - -# Setup Prometheus multiprocess mode +from api.ltx_serve import LTXVideoAPI +from api.mochi_serve import MochiVideoAPI +from configs.combined_settings import CombinedBatchRequest, CombinedItemRequest os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc" if not os.path.exists("/tmp/prometheus_multiproc"): os.makedirs("/tmp/prometheus_multiproc") @@ -33,9 +46,13 @@ registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) + class PrometheusLogger(ls.Logger): - """Custom logger for tracking combined API metrics.""" - + """ + Custom logger for Prometheus metrics. + Tracks request duration for each (model_name, function_name) pair. + """ + def __init__(self): super().__init__() self.function_duration = Histogram( @@ -46,167 +63,198 @@ def __init__(self): ) def process(self, key: str, value: float) -> None: - """Record metric observations.""" - model_name, func_name = key.split(":", 1) if ":" in key else ("unknown", key) + """ + Record metric observations with labels for both model_name and function_name. + `key` is expected to have the format "model_name:function_name". + """ + if ":" in key: + model_name, func_name = key.split(":", 1) + else: + model_name, func_name = "unknown", key + self.function_duration.labels( model_name=model_name, function_name=func_name ).observe(value) -class CombinedRequest(BaseModel): - """Request model for the combined API endpoint.""" - - model_name: str = Field( - ..., - description="Model to use for video generation ('ltx' or 'mochi')" - ) - - class Config: - extra = "allow" # Allow additional fields for model-specific parameters class CombinedVideoAPI(ls.LitAPI): - """Combined API for serving both LTX and Mochi video generation models.""" + """ + Combined Video Generation API for both LTX and Mochi models. + This API handles requests in batch form, even for single items. + + Steps: + 1) setup(device): Initialize LTX and Mochi sub-APIs on the specified device (CPU, GPU). + 2) decode_request(request): Parse the request body using Pydantic `CombinedBatchRequest`. + 3) predict(inputs): Parallel process each item in the batch. + 4) encode_response(outputs): Format the final JSON response. + """ def setup(self, device: str) -> None: - """Initialize both video generation models. - - Args: - device: Target device for model execution """ - logger.info(f"Setting up combined video API on device: {device}") - - # Initialize both APIs + Called once at server startup. + Initializes both the LTX and Mochi APIs on the same device. + """ + logger.info(f"Initializing CombinedVideoAPI on device={device}") self.ltx_api = LTXVideoAPI() self.mochi_api = MochiVideoAPI() - - # Setup each model + self.ltx_api.setup(device=device) self.mochi_api.setup(device=device) - - # Register models for routing + self.model_apis = { "ltx": self.ltx_api, "mochi": self.mochi_api } - - logger.info("Successfully initialized all models") - - def decode_request( - self, - request: Union[Dict[str, Any], List[Dict[str, Any]]] - ) -> Dict[str, Any]: - """Validate request and determine target model. 
- + + logger.info("All sub-APIs (LTX, Mochi) successfully set up") + + def decode_request(self, request: Any) -> Dict[str, List[Dict[str, Any]]]: + """ + Interprets the raw request body as a batch, then validates it. + We unify single vs. multiple requests by requiring a `batch` array. + Args: - request: Raw request data - + request: The raw request data (usually a dict from the body). + Returns: - Decoded request with model selection - + A dictionary with key 'items' containing a list of validated dicts. + Raises: - ValueError: If model_name is invalid + ValidationError if the request doesn't match CombinedBatchRequest schema. """ + # If user directly posted an array, wrap it to match the expected schema if isinstance(request, list): - request = request[0] # Handle single requests for now - - validated = CombinedRequest(**request).dict() - model_name = validated.pop("model_name").lower() - - if model_name not in self.model_apis: - raise ValueError( - f"Invalid model_name: {model_name}. " - f"Available models: {list(self.model_apis.keys())}" - ) - - return { - "model_name": model_name, - "request_data": validated - } + request = {"batch": request} - def predict(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - """Route request to appropriate model and generate video. - - Args: - inputs: Decoded request data - - Returns: - Generation results from selected model + # Validate using CombinedBatchRequest + validated_batch = CombinedBatchRequest(**request) + + # Convert each CombinedItemRequest into a dict for usage in predict + items = [item.dict() for item in validated_batch.batch] + return {"items": items} + + def predict(self, inputs: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]: """ - model_name = inputs["model_name"] - request_data = inputs["request_data"] - model_api = self.model_apis[model_name] - - try: - # Process request through selected model - decoded = model_api.decode_request(request_data) - predictions = model_api.predict(decoded) - result = predictions[0] if predictions else { - "status": "error", - "error": "No result returned" - } - - return { - "model_name": model_name, - "result": result - } - - except Exception as e: - import traceback - logger.error(f"Error in {model_name} prediction: {str(e)}") - return { - "model_name": model_name, - "status": "error", - "error": str(e), - "traceback": traceback.format_exc() - } - finally: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + Execute parallel inference for all items in the 'items' list. - def encode_response(self, output: Dict[str, Any]) -> Dict[str, Any]: - """Encode final response using model-specific encoder. - Args: - output: Raw model output - + inputs: Dictionary with key 'items' -> list of items. + Each item is a dict with fields like 'model_name', 'prompt', etc. + Returns: - Encoded response ready for client + Dictionary with 'batch_results': a list of output dicts, + each containing status, video_id, video_url, etc. """ - model_name = output.get("model_name") - if model_name and model_name in self.model_apis: - result = output.get("result", {}) - - if result.get("status") == "error": + items = inputs["items"] + logger.info(f"Processing batch of {len(items)} request(s) in parallel") + + # We'll define a helper function for one item + def process_single(item: Dict[str, Any]) -> Dict[str, Any]: + """ + Takes a single request dict, delegates to the correct sub-API (LTX or Mochi). + Returns the predicted result (video URL, etc.). 
+ """ + model_name = item.get("model_name", "").lower() + if model_name not in self.model_apis: return { "status": "error", - "error": result.get("error", "Unknown error"), - "traceback": result.get("traceback") + "error": f"Invalid model_name: {model_name}" } - - encoded = self.model_apis[model_name].encode_response(result) - encoded["model_name"] = model_name - return encoded - else: + + sub_api = self.model_apis[model_name] + + # Sub-API workflow: decode -> predict -> single result + # Note: sub_api.decode_request() often returns a list. We'll handle that carefully. + try: + # Prepare sub-request in their expected format + sub_decoded = sub_api.decode_request(item) + sub_pred = sub_api.predict(sub_decoded) + return sub_pred[0] if sub_pred else { + "status": "error", + "error": "No result returned from sub-API." + } + except Exception as e: + logger.error(f"[{model_name}] sub-api error: {e}") + return {"status": "error", "error": str(e), "model_name": model_name} + + # Use a ProcessPoolExecutor to handle CPU-heavy tasks concurrently + results = [] + with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor: + future_to_idx = {} + for idx, item in enumerate(items): + future = executor.submit(process_single, item) + future_to_idx[future] = idx + + for f in as_completed(future_to_idx): + idx = future_to_idx[f] + try: + out = f.result() + out["item_index"] = idx + if "model_name" not in out: + out["model_name"] = items[idx].get("model_name", "unknown") + results.append(out) + except Exception as e: + # If something catastrophic happened in process_single + results.append({"status": "error", "error": str(e), "item_index": idx}) + + # Sort results by item_index so response order matches input order + results.sort(key=lambda x: x["item_index"]) + return {"batch_results": results} + + def encode_response(self, outputs: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert the raw dictionary from `predict` into a final response. + We unify single vs. multiple items: + - The client always receives "batch_results" + (with 1 result if originally a single item). + + Sub-APIs often have their own encode_response() method to standardize the final JSON. + We'll call that to keep consistent format. + + Returns: + The final JSON-serializable dict. + """ + if "batch_results" not in outputs: return { - "status": "error", - "error": output.get("error", "Unknown routing error"), - "traceback": output.get("traceback") + "status": "error", + "error": "No batch_results field found in predict output" } + for item in outputs["batch_results"]: + if item.get("status") == "error": + continue + + model_name = item.get("model_name", "").lower() + if model_name in self.model_apis: + sub_encoded = self.model_apis[model_name].encode_response(item) + item.update(sub_encoded) + + return outputs + + def main(): - """Initialize and start the combined video generation server.""" - # Setup Prometheus metrics + """ + Main entry point for the combined server, exposing /predict on port 8000. + This version logs metrics to Prometheus and logs to console + file. 
+ """ + from litserve import LitServer + + # PROMETHEUS LOGGER prometheus_logger = PrometheusLogger() prometheus_logger.mount( path="/metrics", app=make_asgi_app(registry=registry) ) - # Configure logging - logger.remove() + # LOGGING + logger.remove() # Remove default handler logger.add( sys.stdout, - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function} - {message}", + format="{time:YYYY-MM-DD HH:mm:ss} " + "| {level: <8} " + "| {name}:{function} - " + "{message}", level="INFO" ) logger.add( @@ -216,19 +264,20 @@ def main(): level="DEBUG" ) - # Start server - logger.info("Starting Combined Video Generation Server") api = CombinedVideoAPI() - server = ls.LitServer( + server = LitServer( api, - api_path="/predict", - accelerator="auto", - devices="auto", - max_batch_size=1, + api_path="/api/v1/inference", + accelerator="auto", + devices="auto", + max_batch_size=4, track_requests=True, loggers=[prometheus_logger] ) + + logger.info("Starting combined video generation server on port 8000") server.run(port=8000) + if __name__ == "__main__": main() diff --git a/configs/__pycache__/combined_settings.cpython-310.pyc b/configs/__pycache__/combined_settings.cpython-310.pyc new file mode 100644 index 0000000..659a291 Binary files /dev/null and b/configs/__pycache__/combined_settings.cpython-310.pyc differ diff --git a/configs/combined_settings.py b/configs/combined_settings.py new file mode 100644 index 0000000..bd9d7d9 --- /dev/null +++ b/configs/combined_settings.py @@ -0,0 +1,110 @@ +""" +combined_settings.py + +Central Pydantic models and optional unified config for combining Mochi and LTX requests/settings. + +1) CombinedItemRequest: + - Defines the schema for a single text-to-video request, including model_name (mochi/ltx) + and common fields like prompt, negative_prompt, resolution, etc. + +2) CombinedBatchRequest: + - Defines a list of CombinedItemRequest items, handling batch-mode requests. + +3) CombinedConfig (optional): + - Demonstrates how you might wrap both mochi_settings and ltx_settings into a single config + so everything can be accessed from a unified settings object if needed. + +Usage Example (Batch): + POST /predict + { + "batch": [ + { + "model_name": "mochi", + "prompt": "A calm ocean scene at sunset", + "negative_prompt": "blurry, worst quality", + "num_inference_steps": 50, + "guidance_scale": 4.5, + "height": 480, + "width": 848 + }, + { + "model_name": "ltx", + "prompt": "Golden autumn leaves swirling", + "num_inference_steps": 40, + "guidance_scale": 3.0, + "height": 480, + "width": 704 + } + ] + } +""" + +from pydantic import BaseModel, Field +from typing import List, Optional + +# If you want to embed or reference them here: +from .mochi_settings import MochiSettings +from .ltx_settings import LTXVideoSettings +from pydantic_settings import BaseSettings + + +class CombinedItemRequest(BaseModel): + """ + A single request object for either Mochi or LTX. + + Fields: + model_name (str): Which model to use: 'mochi' or 'ltx'. + prompt (str): Main prompt describing the video content. + negative_prompt (Optional[str]): Text describing what to avoid. + num_inference_steps (Optional[int]): Override for inference steps. + guidance_scale (Optional[float]): Classifier-free guidance scale. + height (Optional[int]): Video height in pixels. + width (Optional[int]): Video width in pixels. + (Add additional fields as needed for your models.) 
+ """ + model_name: str = Field(..., description="Model to use: 'ltx' or 'mochi'.") + prompt: str = Field(..., description="Prompt describing the video content.") + negative_prompt: Optional[str] = Field(None, description="Things to avoid in generation.") + num_inference_steps: Optional[int] = Field(40, description="Number of denoising steps.") + guidance_scale: Optional[float] = Field(3.0, description="Guidance scale for generation.") + height: Optional[int] = Field(480, description="Video height in pixels.") + width: Optional[int] = Field(704, description="Video width in pixels.") + # Add any more fields your sub-models need, e.g. fps, frames, etc. + + +class CombinedBatchRequest(BaseModel): + """ + A batched request containing multiple CombinedItemRequest items. + + Usage: + { + "batch": [ + { "model_name": "mochi", "prompt": "...", ... }, + { "model_name": "ltx", "prompt": "...", ... } + ] + } + """ + batch: List[CombinedItemRequest] = Field( + ..., description="List of multiple CombinedItemRequest items to process in parallel." + ) + + +class CombinedConfig(BaseSettings): + """ + Optional: A unified config that embeds or references your Mochi/LTX model settings. + + This can be used if you want to store and manipulate both sets of settings in one place. + For example, you might define environment variables to override mochi or ltx defaults. + + Usage: + from configs.combined_settings import CombinedConfig + combined_config = CombinedConfig() + # Access mochi or ltx settings: combined_config.mochi_config, combined_config.ltx_config + """ + mochi_config: MochiSettings = MochiSettings() + ltx_config: LTXVideoSettings = LTXVideoSettings() + + class Config: + env_prefix = "COMBINED_" + validate_assignment = True + arbitrary_types_allowed = True