From 2d4c4a338ac73ff621b2601aba8bddcecbe9973d Mon Sep 17 00:00:00 2001 From: vikramxD Date: Tue, 26 Nov 2024 11:40:19 +0000 Subject: [PATCH] Refactor gitignore and gitmodules `` --- .gitignore | 2 +- .gitmodules | 2 +- api/client.py | 18 +++++++ api/logs/ltx_api.log | 12 +++++ api/ltx_serve.py | 46 ++++++++---------- .../__pycache__/ltx_settings.cpython-311.pyc | Bin 7975 -> 8957 bytes configs/ltx_settings.py | 29 ++++++++++- .../__pycache__/ltx_inference.cpython-311.pyc | Bin 18772 -> 18776 bytes .../mp4_to_s3_json.cpython-311.pyc | Bin 1352 -> 1352 bytes .../__pycache__/s3_manager.cpython-311.pyc | Bin 4109 -> 4109 bytes scripts/ltx_inference.py | 2 +- 11 files changed, 82 insertions(+), 29 deletions(-) create mode 100644 api/client.py diff --git a/.gitignore b/.gitignore index 0f66c46..ae54f51 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1 @@ -src/scripts/weights \ No newline at end of file +api/checkpoints \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 8253cf1..e0208d6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "LTX-Video"] - path = LTX-Video + path = ltx url = https://github.com/Lightricks/LTX-Video.git diff --git a/api/client.py b/api/client.py new file mode 100644 index 0000000..12b71e5 --- /dev/null +++ b/api/client.py @@ -0,0 +1,18 @@ + +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import requests + +response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}) +print(f"Status: {response.status_code}\nResponse:\n {response.text}") diff --git a/api/logs/ltx_api.log b/api/logs/ltx_api.log index fef8ad9..42b9871 100644 --- a/api/logs/ltx_api.log +++ b/api/logs/ltx_api.log @@ -1,2 +1,14 @@ 2024-11-23 19:32:34.895 | ERROR | __main__:main:352 - Server failed to start: LitServer.__init__() got an unexpected keyword argument 'generate_client_file' 2024-11-23 19:33:20.102 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 10:03:46.386 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 10:27:38.795 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 10:33:47.301 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 10:37:58.440 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 10:40:43.972 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 10:47:33.232 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 11:09:31.128 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 11:20:44.187 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 11:26:48.145 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 11:27:36.160 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000 +2024-11-26 11:31:41.717 | INFO | __main__:main:343 - Starting LTX video generation server on port 8000 +2024-11-26 11:34:03.736 | INFO | __main__:main:343 - Starting LTX video generation server on port 8000 diff --git a/api/ltx_serve.py b/api/ltx_serve.py index 03c3783..471d84d 100644 --- a/api/ltx_serve.py +++ b/api/ltx_serve.py @@ -103,19 +103,12 @@ def setup(self, device: str) -> None: try: logger.info(f"Initializing LTX video generation on device: {device}") - # Initialize settings - self.settings = LTXVideoSettings( - device=device, - ckpt_dir=os.environ.get("LTX_CKPT_DIR", "checkpoints"), - ) + # Initialize settings with device + self.settings = LTXVideoSettings(device=device) # Initialize inference engine self.engine = LTXInference(self.settings) - # Create output directory - self.output_dir = Path("outputs") - self.output_dir.mkdir(parents=True, exist_ok=True) - logger.info("LTX setup completed successfully") except Exception as e: @@ -193,21 +186,21 @@ def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # Validate request generation_request = VideoGenerationRequest(**request) + # Update settings with request parameters + self.settings.prompt = generation_request.prompt + self.settings.negative_prompt = generation_request.negative_prompt + self.settings.num_inference_steps = generation_request.num_inference_steps + self.settings.guidance_scale = generation_request.guidance_scale + self.settings.height = generation_request.height + self.settings.width = generation_request.width + self.settings.num_frames = generation_request.num_frames + self.settings.frame_rate = generation_request.frame_rate + self.settings.seed = generation_request.seed + # Create temporary directory for output with tempfile.TemporaryDirectory() as temp_dir: temp_video_path = Path(temp_dir) / f"ltx_{int(time.time())}.mp4" - - # Update settings with request parameters - self.settings.prompt = generation_request.prompt - self.settings.negative_prompt = generation_request.negative_prompt - self.settings.num_inference_steps = generation_request.num_inference_steps - self.settings.guidance_scale = generation_request.guidance_scale - self.settings.height = generation_request.height - self.settings.width = generation_request.width - self.settings.num_frames = generation_request.num_frames - self.settings.frame_rate = generation_request.frame_rate - self.settings.seed = generation_request.seed - self.settings.output_path = str(temp_video_path) + self.settings.output_path = temp_video_path # Generate video logger.info(f"Starting generation for prompt: {generation_request.prompt}") @@ -218,7 +211,10 @@ def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: self.log("inference_time", generation_time) # Get memory statistics - memory_stats = self.engine.get_memory_stats() + memory_stats = { + "gpu_allocated": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0, + "gpu_reserved": torch.cuda.memory_reserved() if torch.cuda.is_available() else 0 + } # Upload to S3 s3_response = mp4_to_s3_json( @@ -247,9 +243,9 @@ def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: }) finally: - # Cleanup - if hasattr(self.engine, 'clear_memory'): - self.engine.clear_memory() + # Clear CUDA cache after each generation + if torch.cuda.is_available(): + torch.cuda.empty_cache() except Exception as e: logger.error(f"Error in predict method: {e}") diff --git a/configs/__pycache__/ltx_settings.cpython-311.pyc b/configs/__pycache__/ltx_settings.cpython-311.pyc index cf31d52a10b191abecb38d63abf5f016f88bdb22..f4b1e8f9f32dc23123e10267876e74223f2784a5 100644 GIT binary patch delta 1703 zcmZ`(TTdHD6rNdMu(=pV#t93A4x}M$n?O}ULv9cn%q0oX0ok zwzizinphP*djZT6Vs?6tp1(c{5Z=7l3~TAC7k${6iFn2_v!%OyH8x61~kVimnht!Wz z_bP_8VOJJ=nT|3yu-p+l9b2-mg8A+_dSv}tA-YBH*4E>#m! z9CQ3)4?GrklQ}VwPRub*vPe3nNEb%L^#Uy5ST+1 zbMZg-0|&lR{JnnF=DG?26(=7zKlV-GQS*UM!Pkrb`0^OPXMXA*#y5+<`43^o2S8gf z|2*7_RkPv9!`8P!=Ticv0ThW(3b{qdb_nKZjNq@$c;ifNe!JD+6VqaY{dX(CZIneH z!n3kK-eyK5XavsmjWB~4F-EkskxCncGhK_P4OL5$uuj6nNT=f9jF8S4Q=>8CB)Oqd zBc^mgIwv72KZQ^Z0`YNBbP=6!aRViPlo_E)(9rXsnU1}U?|LEcdlBsrz$fUDw_&+v zW#+5kdVS}*tCKf^b$7uzuvis~Ufo6%FXCNQI&?Q@D35KJxE{Egu>9`mH(X>>UQbuzGMqq1JB0QX6de=5Tkkr#qzQx|l zBC1a2VTQtnItNX@p&Bug&?t+;EnH8loF!P&*b(co3-aNj0_b$o*Env zgx(q)8XB7zDY@ZjSWQyKOCFIT40tHX7^JF;*%p=TxTH!EAu-04><_r4eY_tamOg;c zEZkDL7RgQdy0R1b<>kpjVxyra@7}CwSizt7=leH(N0tZHYug{m9c%K2{CeKAS?~KI zS7_a+?_PIxTX(9w%I6>;pBF#{F!N#KLgcg%djU%JIZcc6GeYkZa8|%M0p|sXzop76 zGAblDK&fUv9@g~)Gh!O$?+ddk%noa;c~Tfs0v61I+$Qzdi@(Z>^u}geN5Q}8g8Xa6 zx8Y*h?PzyCK>$w$tlQ<4o-%ScoECrsByaJr<}dL(hb4znNUu%bKejgeU9OBk-3u;{ wz>% delta 667 zcmezCy4;R$IWI340}#0CI;F1_+sHSGk+EmYO6lP*bWl7;mvBFR*oywicCJEEZ090=cS8tPIi=kcyY^EJZO^Rxi zY^nfIhdr3586^kiIe>Y(QSvTet|OT1l**aLl;YgN5~Yyhlqv~iy8zjWKxZkXa4us6 zIv3m(586%a}PlfmUg<6?sn<5D8**o1847$(X&lSEPfH zF>Xl)=2$U7M%Fl>vPqMR#Z4L0C$AL0Eawm87X^TbK#&1UMZO@m42aDFG;Eec z8sm(~jFKV3ni@q3AgM$Ukpv==Cu_)*g&PBzw^)+$^K*(!L84|L!W=|cfCx(vVFe<< z2D^b+oIv6hcXCc*adB>HNk)E3Q4B~p9VE9~iKc6KXz!Wxn8$2OA0i DJPwZw diff --git a/configs/ltx_settings.py b/configs/ltx_settings.py index 4df90d8..0231bdc 100644 --- a/configs/ltx_settings.py +++ b/configs/ltx_settings.py @@ -48,6 +48,12 @@ class LTXVideoSettings(BaseSettings): width: int = Field(704, ge=256, le=1280, description="Width of output video frames") num_frames: int = Field(121, ge=1, le=257, description="Number of frames to generate") frame_rate: int = Field(25, ge=1, le=60, description="Frame rate of output video") + num_images_per_prompt: int = Field( + 1, + ge=1, + le=4, + description="Number of videos to generate per prompt" + ) # Model Settings bfloat16: bool = Field(False, description="Use bfloat16 precision") @@ -158,4 +164,25 @@ def get_model_paths(self) -> tuple[Path, Path, Path]: vae_dir = self.ckpt_dir / "vae" scheduler_dir = self.ckpt_dir / "scheduler" - return unet_dir, vae_dir, scheduler_dir \ No newline at end of file + return unet_dir, vae_dir, scheduler_dir + + def get_output_resolution(self) -> tuple[int, int]: + """Get the output resolution as a tuple of (height, width).""" + return (self.height, self.width) + + def get_padded_num_frames(self) -> int: + """ + Calculate the padded number of frames. + Ensures the number of frames is compatible with model requirements. + """ + # Common video models often require frame counts to be multiples of 8 + FRAME_PADDING = 8 + + # Calculate padding needed to reach next multiple of FRAME_PADDING + remainder = self.num_frames % FRAME_PADDING + if remainder == 0: + return self.num_frames + + padding_needed = FRAME_PADDING - remainder + return self.num_frames + padding_needed + \ No newline at end of file diff --git a/scripts/__pycache__/ltx_inference.cpython-311.pyc b/scripts/__pycache__/ltx_inference.cpython-311.pyc index 82a47c335aacda5fd6ea4fed5cfebd3e5dce5741..b0c2b93c5972a05c40a74c652b1cf06d0732c4f1 100644 GIT binary patch delta 63 zcmcaIiSfoHM&9MTyj%=G5V^=TJ$xhY1PNB@oRSK?$@3)U3%Fh2KtY>*CHqtvZ*9I~ HSHT4Uh9?z& delta 59 zcmcaHiSf!LM&9MTyj%=Gz@h1s9=wrvf&`=Z<{uxdajb diff --git a/scripts/ltx_inference.py b/scripts/ltx_inference.py index 461625f..d591b37 100644 --- a/scripts/ltx_inference.py +++ b/scripts/ltx_inference.py @@ -128,7 +128,7 @@ def _initialize_pipeline(self) -> LTXVideoPipeline: def _load_scheduler(self, scheduler_dir: Path): """Load and configure the scheduler""" - from ltx_video.schedulers.rf import RectifiedFlowScheduler + from ltx.ltx_video.schedulers.rf import RectifiedFlowScheduler scheduler_config_path = scheduler_dir / "scheduler_config.json" scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path) return RectifiedFlowScheduler.from_config(scheduler_config)