diff --git a/.gitignore b/.gitignore
index 0f66c46..ae54f51 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1 @@
-src/scripts/weights
\ No newline at end of file
+api/checkpoints
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
index 8253cf1..e0208d6 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
 [submodule "LTX-Video"]
-	path = LTX-Video
+	path = ltx
 	url = https://github.com/Lightricks/LTX-Video.git
diff --git a/api/client.py b/api/client.py
new file mode 100644
index 0000000..12b71e5
--- /dev/null
+++ b/api/client.py
@@ -0,0 +1,18 @@
+
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import requests
+
+response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0})
+print(f"Status: {response.status_code}\nResponse:\n {response.text}")
diff --git a/api/logs/ltx_api.log b/api/logs/ltx_api.log
index fef8ad9..42b9871 100644
--- a/api/logs/ltx_api.log
+++ b/api/logs/ltx_api.log
@@ -1,2 +1,14 @@
 2024-11-23 19:32:34.895 | ERROR | __main__:main:352 - Server failed to start: LitServer.__init__() got an unexpected keyword argument 'generate_client_file'
 2024-11-23 19:33:20.102 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:03:46.386 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:27:38.795 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:33:47.301 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:37:58.440 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:40:43.972 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:47:33.232 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:09:31.128 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:20:44.187 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:26:48.145 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:27:36.160 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:31:41.717 | INFO | __main__:main:343 - Starting LTX video generation server on port 8000
+2024-11-26 11:34:03.736 | INFO | __main__:main:343 - Starting LTX video generation server on port 8000
diff --git a/api/ltx_serve.py b/api/ltx_serve.py
index 03c3783..471d84d 100644
--- a/api/ltx_serve.py
+++ b/api/ltx_serve.py
@@ -103,19 +103,12 @@ def setup(self, device: str) -> None:
         try:
             logger.info(f"Initializing LTX video generation on device: {device}")
 
-            # Initialize settings
-            self.settings = LTXVideoSettings(
-                device=device,
-                ckpt_dir=os.environ.get("LTX_CKPT_DIR", "checkpoints"),
-            )
+            # Initialize settings with device
+            self.settings = LTXVideoSettings(device=device)
 
             # Initialize inference engine
             self.engine = LTXInference(self.settings)
 
-            # Create output directory
-            self.output_dir = Path("outputs")
-            self.output_dir.mkdir(parents=True, exist_ok=True)
-
             logger.info("LTX setup completed successfully")
 
         except Exception as e:
@@ -193,21 +186,21 @@ def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                     # Validate request
                     generation_request = VideoGenerationRequest(**request)
 
+                    # Update settings with request parameters
+                    self.settings.prompt = generation_request.prompt
+                    self.settings.negative_prompt = generation_request.negative_prompt
+                    self.settings.num_inference_steps = generation_request.num_inference_steps
+                    self.settings.guidance_scale = generation_request.guidance_scale
+                    self.settings.height = generation_request.height
+                    self.settings.width = generation_request.width
+                    self.settings.num_frames = generation_request.num_frames
+                    self.settings.frame_rate = generation_request.frame_rate
+                    self.settings.seed = generation_request.seed
+
                     # Create temporary directory for output
                     with tempfile.TemporaryDirectory() as temp_dir:
                         temp_video_path = Path(temp_dir) / f"ltx_{int(time.time())}.mp4"
-
-                        # Update settings with request parameters
-                        self.settings.prompt = generation_request.prompt
-                        self.settings.negative_prompt = generation_request.negative_prompt
-                        self.settings.num_inference_steps = generation_request.num_inference_steps
-                        self.settings.guidance_scale = generation_request.guidance_scale
-                        self.settings.height = generation_request.height
-                        self.settings.width = generation_request.width
-                        self.settings.num_frames = generation_request.num_frames
-                        self.settings.frame_rate = generation_request.frame_rate
-                        self.settings.seed = generation_request.seed
-                        self.settings.output_path = str(temp_video_path)
+                        self.settings.output_path = temp_video_path
 
                         # Generate video
                         logger.info(f"Starting generation for prompt: {generation_request.prompt}")
@@ -218,7 +211,10 @@ def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                         self.log("inference_time", generation_time)
 
                         # Get memory statistics
-                        memory_stats = self.engine.get_memory_stats()
+                        memory_stats = {
+                            "gpu_allocated": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
+                            "gpu_reserved": torch.cuda.memory_reserved() if torch.cuda.is_available() else 0
+                        }
 
                         # Upload to S3
                         s3_response = mp4_to_s3_json(
@@ -247,9 +243,9 @@ def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                         })
 
                 finally:
-                    # Cleanup
-                    if hasattr(self.engine, 'clear_memory'):
-                        self.engine.clear_memory()
+                    # Clear CUDA cache after each generation
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
 
         except Exception as e:
             logger.error(f"Error in predict method: {e}")
diff --git a/configs/__pycache__/ltx_settings.cpython-311.pyc b/configs/__pycache__/ltx_settings.cpython-311.pyc
index cf31d52..f4b1e8f 100644
Binary files a/configs/__pycache__/ltx_settings.cpython-311.pyc and b/configs/__pycache__/ltx_settings.cpython-311.pyc differ
diff --git a/configs/ltx_settings.py b/configs/ltx_settings.py
index 4df90d8..0231bdc 100644
--- a/configs/ltx_settings.py
+++ b/configs/ltx_settings.py
@@ -48,6 +48,12 @@ class LTXVideoSettings(BaseSettings):
     width: int = Field(704, ge=256, le=1280, description="Width of output video frames")
     num_frames: int = Field(121, ge=1, le=257, description="Number of frames to generate")
     frame_rate: int = Field(25, ge=1, le=60, description="Frame rate of output video")
+    num_images_per_prompt: int = Field(
+        1,
+        ge=1,
+        le=4,
+        description="Number of videos to generate per prompt"
+    )
 
     # Model Settings
     bfloat16: bool = Field(False, description="Use bfloat16 precision")
@@ -158,4 +164,25 @@ def get_model_paths(self) -> tuple[Path, Path, Path]:
         vae_dir = self.ckpt_dir / "vae"
         scheduler_dir = self.ckpt_dir / "scheduler"
 
-        return unet_dir, vae_dir, scheduler_dir
\ No newline at end of file
+        return unet_dir, vae_dir, scheduler_dir
+
+    def get_output_resolution(self) -> tuple[int, int]:
+        """Get the output resolution as a tuple of (height, width)."""
+        return (self.height, self.width)
+
+    def get_padded_num_frames(self) -> int:
+        """
+        Calculate the padded number of frames.
+        Ensures the number of frames is compatible with model requirements.
+        """
+        # Common video models often require frame counts to be multiples of 8
+        FRAME_PADDING = 8
+
+        # Calculate padding needed to reach next multiple of FRAME_PADDING
+        remainder = self.num_frames % FRAME_PADDING
+        if remainder == 0:
+            return self.num_frames
+
+        padding_needed = FRAME_PADDING - remainder
+        return self.num_frames + padding_needed
+
\ No newline at end of file
diff --git a/scripts/__pycache__/ltx_inference.cpython-311.pyc b/scripts/__pycache__/ltx_inference.cpython-311.pyc
index 82a47c3..b0c2b93 100644
Binary files a/scripts/__pycache__/ltx_inference.cpython-311.pyc and b/scripts/__pycache__/ltx_inference.cpython-311.pyc differ
diff --git a/scripts/__pycache__/mp4_to_s3_json.cpython-311.pyc b/scripts/__pycache__/mp4_to_s3_json.cpython-311.pyc
index 3c6b8f9..4c9cadd 100644
Binary files a/scripts/__pycache__/mp4_to_s3_json.cpython-311.pyc and b/scripts/__pycache__/mp4_to_s3_json.cpython-311.pyc differ
diff --git a/scripts/__pycache__/s3_manager.cpython-311.pyc b/scripts/__pycache__/s3_manager.cpython-311.pyc
index 6ea7a60..86b691a 100644
Binary files a/scripts/__pycache__/s3_manager.cpython-311.pyc and b/scripts/__pycache__/s3_manager.cpython-311.pyc differ
diff --git a/scripts/ltx_inference.py b/scripts/ltx_inference.py
index 461625f..d591b37 100644
--- a/scripts/ltx_inference.py
+++ b/scripts/ltx_inference.py
@@ -128,7 +128,7 @@ def _initialize_pipeline(self) -> LTXVideoPipeline:
 
     def _load_scheduler(self, scheduler_dir: Path):
         """Load and configure the scheduler"""
-        from ltx_video.schedulers.rf import RectifiedFlowScheduler
+        from ltx.ltx_video.schedulers.rf import RectifiedFlowScheduler
        scheduler_config_path = scheduler_dir / "scheduler_config.json"
         scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
         return RectifiedFlowScheduler.from_config(scheduler_config)
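
Note on the new predict() contract: each field of a validated VideoGenerationRequest is now copied onto LTXVideoSettings before generation. A minimal client sketch for that contract follows. The field names are taken from this diff, but the flat JSON envelope and the example values are assumptions, not the canonical client (api/client.py above still posts a placeholder {"input": 4.0} body).

# Hypothetical client for the /predict route served by api/ltx_serve.py.
# Field names mirror VideoGenerationRequest as wired into LTXVideoSettings
# in this diff; the example values and payload shape are assumptions.
import requests

payload = {
    "prompt": "a red fox running through fresh snow, cinematic",
    "negative_prompt": "worst quality, blurry, distorted",
    "num_inference_steps": 40,   # assumed example value
    "guidance_scale": 3.0,       # assumed example value
    "height": 480,               # assumed; must satisfy LTXVideoSettings bounds
    "width": 704,                # default from configs/ltx_settings.py
    "num_frames": 121,           # default from configs/ltx_settings.py
    "frame_rate": 25,            # default from configs/ltx_settings.py
    "seed": 42,
}

response = requests.post("http://127.0.0.1:8000/predict", json=payload)
print(response.status_code, response.text)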
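The frame-padding helper added to configs/ltx_settings.py rounds num_frames up to the next multiple of 8. A standalone restatement of that arithmetic, with the default 121-frame setting worked through:

# Same rounding rule as LTXVideoSettings.get_padded_num_frames() in this diff.
FRAME_PADDING = 8

def padded_num_frames(num_frames: int) -> int:
    remainder = num_frames % FRAME_PADDING
    if remainder == 0:
        return num_frames
    return num_frames + (FRAME_PADDING - remainder)

# Default num_frames=121: 121 % 8 == 1, so 7 padding frames are added.
assert padded_num_frames(121) == 128
assert padded_num_frames(128) == 128  # already a multiple of 8, unchanged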