From 47fcaabaf822eaf07e16adfd8ee9d94683f002ef Mon Sep 17 00:00:00 2001
From: vikramxD
Date: Sat, 23 Nov 2024 19:52:54 +0000
Subject: [PATCH] Add LTX-Video submodule

Add uv installation and setup to run.sh
Update dependencies in requirements.txt
Add pyproject.toml to Dockerfile
Update README.md
Delete LTX-Video submodule
Add AWS settings to configs/aws_settings.py
Add error logs to api/logs/ltx_api.log
Update pyproject.toml
Delete AWS settings from src/configs/aws_settings.py
Delete AWS settings cache file
Delete logs from api/logs/api.log
Add mp4_to_s3_json.py script
Delete mp4_to_s3_json.py script from src/scripts
Add mochi_weights.py to configs
Delete mochi_weights.py from src/configs
Update mp4_to_s3_json.py script
Delete logs from api/logs/api.log
---
 LTX-Video                                          |   1 -
 {src/api => api}/logs/api.log                      |   0
 api/logs/ltx_api.log                               |   2 +
 api/ltx_serve.py                                   | 355 +++++++++++++++++
 {src/api => api}/mochi_serve.py                    |   0
 .../__pycache__/aws_settings.cpython-310.pyc       | Bin
 .../__pycache__/aws_settings.cpython-311.pyc       | Bin 0 -> 726 bytes
 .../__pycache__/ltx_settings.cpython-311.pyc       | Bin 0 -> 7975 bytes
 .../__pycache__/mochi_settings.cpython-310.pyc     | Bin
 .../__pycache__/mochi_weights.cpython-310.pyc      | Bin
 {src/configs => configs}/aws_settings.py           |   0
 configs/ltx_settings.py                            | 161 ++++++++
 {src/configs => configs}/mochi_settings.py         |   0
 {src/configs => configs}/mochi_weights.py          |   0
 pyproject.toml                                     |   8 +
 .../__pycache__/ltx_inference.cpython-311.pyc      | Bin 0 -> 18772 bytes
 .../__pycache__/mochi_diffusers.cpython-310.pyc    | Bin
 .../__pycache__/mp4_to_s3_json.cpython-310.pyc     | Bin
 .../__pycache__/mp4_to_s3_json.cpython-311.pyc     | Bin 0 -> 1352 bytes
 .../__pycache__/s3_manager.cpython-310.pyc         | Bin
 .../__pycache__/s3_manager.cpython-311.pyc         | Bin 0 -> 4109 bytes
 .../convert_mochi_to_diffusers.py                  |   0
 {src/scripts => scripts}/download_weights.py       |   0
 scripts/ltx_inference.py                           | 367 ++++++++++++++++++
 {src/scripts => scripts}/mochi_diffusers.py        |   0
 {src/scripts => scripts}/mp4_to_s3_json.py         |   0
 {src/scripts => scripts}/s3_manager.py             |   0
 27 files changed, 893 insertions(+), 1 deletion(-)
 delete mode 160000 LTX-Video
 rename {src/api => api}/logs/api.log (100%)
 create mode 100644 api/logs/ltx_api.log
 create mode 100644 api/ltx_serve.py
 rename {src/api => api}/mochi_serve.py (100%)
 rename {src/configs => configs}/__pycache__/aws_settings.cpython-310.pyc (100%)
 create mode 100644 configs/__pycache__/aws_settings.cpython-311.pyc
 create mode 100644 configs/__pycache__/ltx_settings.cpython-311.pyc
 rename {src/configs => configs}/__pycache__/mochi_settings.cpython-310.pyc (100%)
 rename {src/configs => configs}/__pycache__/mochi_weights.cpython-310.pyc (100%)
 rename {src/configs => configs}/aws_settings.py (100%)
 create mode 100644 configs/ltx_settings.py
 rename {src/configs => configs}/mochi_settings.py (100%)
 rename {src/configs => configs}/mochi_weights.py (100%)
 create mode 100644 scripts/__pycache__/ltx_inference.cpython-311.pyc
 rename {src/scripts => scripts}/__pycache__/mochi_diffusers.cpython-310.pyc (100%)
 rename {src/scripts => scripts}/__pycache__/mp4_to_s3_json.cpython-310.pyc (100%)
 create mode 100644 scripts/__pycache__/mp4_to_s3_json.cpython-311.pyc
 rename {src/scripts => scripts}/__pycache__/s3_manager.cpython-310.pyc (100%)
 create mode 100644 scripts/__pycache__/s3_manager.cpython-311.pyc
 rename {src/scripts => scripts}/convert_mochi_to_diffusers.py (100%)
 rename {src/scripts => scripts}/download_weights.py (100%)
 create mode 100644 scripts/ltx_inference.py
 rename {src/scripts => scripts}/mochi_diffusers.py (100%)
 rename {src/scripts => scripts}/mp4_to_s3_json.py (100%)
 rename {src/scripts => scripts}/s3_manager.py (100%)
diff --git a/LTX-Video b/LTX-Video
deleted file mode 160000
index 23f1048..0000000
--- a/LTX-Video
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 23f104870f57f31f286a2c6970f995c55b6f505f
diff --git a/src/api/logs/api.log b/api/logs/api.log
similarity index 100%
rename from src/api/logs/api.log
rename to api/logs/api.log
diff --git a/api/logs/ltx_api.log b/api/logs/ltx_api.log
new file mode 100644
index 0000000..fef8ad9
--- /dev/null
+++ b/api/logs/ltx_api.log
@@ -0,0 +1,2 @@
+2024-11-23 19:32:34.895 | ERROR | __main__:main:352 - Server failed to start: LitServer.__init__() got an unexpected keyword argument 'generate_client_file'
+2024-11-23 19:33:20.102 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
diff --git a/api/ltx_serve.py b/api/ltx_serve.py
new file mode 100644
index 0000000..03c3783
--- /dev/null
+++ b/api/ltx_serve.py
@@ -0,0 +1,355 @@
+"""
+LitServe API implementation for LTX video generation service.
+"""
+
+import os
+import sys
+import time
+import tempfile
+from typing import Dict, Any, List, Union, Optional
+from pathlib import Path
+from pydantic import BaseModel, Field
+from litserve import LitAPI, LitServer, Logger
+from loguru import logger
+from prometheus_client import (
+    CollectorRegistry,
+    Histogram,
+    make_asgi_app,
+    multiprocess
+)
+
+from configs.ltx_settings import LTXVideoSettings
+from scripts.ltx_inference import LTXInference
+# Import the upload helper itself, not the module, so it is callable below
+from scripts.mp4_to_s3_json import mp4_to_s3_json
+
+# Set up prometheus multiprocess mode
+os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc_dir"
+if not os.path.exists("/tmp/prometheus_multiproc_dir"):
+    os.makedirs("/tmp/prometheus_multiproc_dir")
+
+# Initialize prometheus registry
+registry = CollectorRegistry()
+multiprocess.MultiProcessCollector(registry)
+
+class PrometheusLogger(Logger):
+    """Custom logger for Prometheus metrics."""
+
+    def __init__(self):
+        super().__init__()
+        self.function_duration = Histogram(
+            "ltx_request_processing_seconds",
+            "Time spent processing LTX video request",
+            ["function_name"],
+            registry=registry
+        )
+
+    def process(self, key: str, value: float) -> None:
+        """Process and record metric."""
+        self.function_duration.labels(function_name=key).observe(value)
+
+class VideoGenerationRequest(BaseModel):
+    """Model representing a video generation request."""
+
+    prompt: str = Field(..., description="Text description of the video to generate")
+    negative_prompt: Optional[str] = Field(
+        "worst quality, inconsistent motion, blurry, jittery, distorted",
+        description="Text description of what to avoid"
+    )
+    num_inference_steps: int = Field(
+        40,
+        ge=1,
+        le=100,
+        description="Number of inference steps"
+    )
+    guidance_scale: float = Field(
+        3.0,
+        ge=1.0,
+        le=20.0,
+        description="Guidance scale for generation"
+    )
+    height: int = Field(
+        480,
+        ge=256,
+        le=720,
+        multiple_of=32,
+        description="Video height in pixels"
+    )
+    width: int = Field(
+        704,
+        ge=256,
+        le=1280,
+        multiple_of=32,
+        description="Video width in pixels"
+    )
+    num_frames: int = Field(
+        121,
+        ge=1,
+        le=257,
+        description="Number of frames to generate"
+    )
+    frame_rate: int = Field(
+        25,
+        ge=1,
+        le=60,
+        description="Frames per second for output"
+    )
+    seed: Optional[int] = Field(None, description="Random seed for generation")
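
As a quick illustration of how this schema behaves (a sketch, not part of the patch; it uses only the fields declared above):

    from pydantic import ValidationError

    try:
        req = VideoGenerationRequest(prompt="a red fox running through snow")
        print(req.num_inference_steps)  # 40 -- defaults fill in
    except ValidationError as err:
        print(err)  # raised e.g. if height is not a multiple of 32

Note that prompt is the only required field; everything else falls back to the defaults shown.
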
model.""" + try: + logger.info(f"Initializing LTX video generation on device: {device}") + + # Initialize settings + self.settings = LTXVideoSettings( + device=device, + ckpt_dir=os.environ.get("LTX_CKPT_DIR", "checkpoints"), + ) + + # Initialize inference engine + self.engine = LTXInference(self.settings) + + # Create output directory + self.output_dir = Path("outputs") + self.output_dir.mkdir(parents=True, exist_ok=True) + + logger.info("LTX setup completed successfully") + + except Exception as e: + logger.error(f"Error during LTX setup: {e}") + raise + + def decode_request( + self, + request: Union[Dict[str, Any], List[Dict[str, Any]]] + ) -> List[Dict[str, Any]]: + """Decode and validate the incoming request.""" + try: + # Ensure request is a list + if not isinstance(request, list): + request = [request] + + # Validate each request + validated_requests = [ + VideoGenerationRequest(**req).model_dump() + for req in request + ] + return validated_requests + + except Exception as e: + logger.error(f"Error in decode_request: {e}") + raise + + def batch( + self, + inputs: Union[Dict[str, Any], List[Dict[str, Any]]] + ) -> Dict[str, List[Any]]: + """Prepare inputs for batch processing.""" + try: + # Convert single input to list + if not isinstance(inputs, list): + inputs = [inputs] + + # Get default values + defaults = VideoGenerationRequest().model_dump() + + # Initialize batch dictionary + batched = { + "prompt": [], + "negative_prompt": [], + "num_inference_steps": [], + "guidance_scale": [], + "height": [], + "width": [], + "num_frames": [], + "frame_rate": [], + "seed": [] + } + + # Fill batch dictionary + for input_item in inputs: + for key in batched.keys(): + value = input_item.get(key, defaults.get(key)) + batched[key].append(value) + + return batched + + except Exception as e: + logger.error(f"Error in batch processing: {e}") + raise + + def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Process inputs and generate videos.""" + results = [] + + try: + for request in inputs: + start_time = time.time() + + try: + # Validate request + generation_request = VideoGenerationRequest(**request) + + # Create temporary directory for output + with tempfile.TemporaryDirectory() as temp_dir: + temp_video_path = Path(temp_dir) / f"ltx_{int(time.time())}.mp4" + + # Update settings with request parameters + self.settings.prompt = generation_request.prompt + self.settings.negative_prompt = generation_request.negative_prompt + self.settings.num_inference_steps = generation_request.num_inference_steps + self.settings.guidance_scale = generation_request.guidance_scale + self.settings.height = generation_request.height + self.settings.width = generation_request.width + self.settings.num_frames = generation_request.num_frames + self.settings.frame_rate = generation_request.frame_rate + self.settings.seed = generation_request.seed + self.settings.output_path = str(temp_video_path) + + # Generate video + logger.info(f"Starting generation for prompt: {generation_request.prompt}") + self.engine.generate() + + end_time = time.time() + generation_time = end_time - start_time + self.log("inference_time", generation_time) + + # Get memory statistics + memory_stats = self.engine.get_memory_stats() + + # Upload to S3 + s3_response = mp4_to_s3_json( + temp_video_path, + f"ltx_{int(time.time())}.mp4" + ) + + result = { + "status": "success", + "video_id": s3_response["video_id"], + "video_url": s3_response["url"], + "prompt": generation_request.prompt, + "generation_params": 
+
+    def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Process inputs and generate videos."""
+        results = []
+
+        try:
+            for request in inputs:
+                start_time = time.time()
+
+                try:
+                    # Validate request
+                    generation_request = VideoGenerationRequest(**request)
+
+                    # Create temporary directory for output
+                    with tempfile.TemporaryDirectory() as temp_dir:
+                        temp_video_path = Path(temp_dir) / f"ltx_{int(time.time())}.mp4"
+
+                        # Update settings with request parameters
+                        self.settings.prompt = generation_request.prompt
+                        self.settings.negative_prompt = generation_request.negative_prompt
+                        self.settings.num_inference_steps = generation_request.num_inference_steps
+                        self.settings.guidance_scale = generation_request.guidance_scale
+                        self.settings.height = generation_request.height
+                        self.settings.width = generation_request.width
+                        self.settings.num_frames = generation_request.num_frames
+                        self.settings.frame_rate = generation_request.frame_rate
+                        self.settings.seed = generation_request.seed
+                        self.settings.output_path = str(temp_video_path)
+
+                        # Generate video
+                        logger.info(f"Starting generation for prompt: {generation_request.prompt}")
+                        self.engine.generate()
+
+                        end_time = time.time()
+                        generation_time = end_time - start_time
+                        self.log("inference_time", generation_time)
+
+                        # Get memory statistics
+                        memory_stats = self.engine.get_memory_stats()
+
+                        # Upload to S3
+                        s3_response = mp4_to_s3_json(
+                            temp_video_path,
+                            f"ltx_{int(time.time())}.mp4"
+                        )
+
+                        result = {
+                            "status": "success",
+                            "video_id": s3_response["video_id"],
+                            "video_url": s3_response["url"],
+                            "prompt": generation_request.prompt,
+                            "generation_params": generation_request.model_dump(),
+                            "time_taken": generation_time,
+                            "memory_usage": memory_stats
+                        }
+                        results.append(result)
+
+                        logger.info("Generation completed successfully")
+
+                except Exception as e:
+                    logger.error(f"Error in generation: {e}")
+                    results.append({
+                        "status": "error",
+                        "error": str(e)
+                    })
+
+                finally:
+                    # Cleanup
+                    if hasattr(self.engine, 'clear_memory'):
+                        self.engine.clear_memory()
+
+        except Exception as e:
+            logger.error(f"Error in predict method: {e}")
+            results.append({
+                "status": "error",
+                "error": str(e)
+            })
+
+        return results if results else [{"status": "error", "error": "No results generated"}]
+
+    def unbatch(self, outputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Convert batched outputs back to individual results."""
+        return outputs
+
+    def encode_response(
+        self,
+        output: Union[Dict[str, Any], List[Dict[str, Any]]]
+    ) -> Dict[str, Any]:
+        """Encode the output for response."""
+        try:
+            # Handle list output
+            if isinstance(output, list):
+                output = output[0] if output else {
+                    "status": "error",
+                    "error": "No output generated"
+                }
+
+            # Handle error cases
+            if output.get("status") == "error":
+                return {
+                    "status": "error",
+                    "error": output.get("error", "Unknown error"),
+                    "item_index": output.get("item_index")
+                }
+
+            # Return successful response
+            return {
+                "status": "success",
+                "video_id": output.get("video_id"),
+                "video_url": output.get("video_url"),
+                "generation_info": {
+                    "prompt": output.get("prompt"),
+                    "parameters": output.get("generation_params", {})
+                },
+                "performance": {
+                    "time_taken": round(output.get("time_taken", 0), 2),
+                    "memory_usage": output.get("memory_usage", {})
+                }
+            }
+
+        except Exception as e:
+            logger.error(f"Error in encode_response: {e}")
+            return {
+                "status": "error",
+                "error": str(e)
+            }
+
+def main():
+    """Main entry point for the API server."""
+    # Initialize Prometheus logger
+    prometheus_logger = PrometheusLogger()
+    prometheus_logger.mount(
+        path="/api/v1/metrics",
+        app=make_asgi_app(registry=registry)
+    )
+
+    # Configure logging
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function} - {message}",
+        level="INFO"
+    )
+    logger.add(
+        "logs/ltx_api.log",
+        rotation="100 MB",
+        retention="1 week",
+        level="DEBUG"
+    )
+
+    try:
+        # Initialize API and server
+        api = LTXVideoAPI()
+        server = LitServer(
+            api,
+            api_path='/api/v1/video/ltx',
+            accelerator="auto",
+            devices="auto",
+            max_batch_size=1,
+            track_requests=True,
+            loggers=prometheus_logger,
+        )
+
+        # Start server
+        logger.info("Starting LTX video generation server on port 8000")
+        server.run(port=8000)
+
+    except Exception as e:
+        logger.error(f"Server failed to start: {e}")
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
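
With the server above running (python api/ltx_serve.py), a minimal client sketch looks like this; the endpoint path and response keys come straight from main() and encode_response(), while the host, port, and prompt are assumptions for a local run:

    import requests

    payload = {"prompt": "a sailboat drifting at sunset", "num_frames": 121}
    resp = requests.post(
        "http://localhost:8000/api/v1/video/ltx",
        json=payload,
        timeout=600,  # generation can take minutes on a single GPU
    )
    resp.raise_for_status()
    body = resp.json()
    print(body["status"], body.get("video_url"))

Prometheus request-duration metrics are exposed separately at /api/v1/metrics.
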
diff --git a/src/api/mochi_serve.py b/api/mochi_serve.py
similarity index 100%
rename from src/api/mochi_serve.py
rename to api/mochi_serve.py
diff --git a/src/configs/__pycache__/aws_settings.cpython-310.pyc b/configs/__pycache__/aws_settings.cpython-310.pyc
similarity index 100%
rename from src/configs/__pycache__/aws_settings.cpython-310.pyc
rename to configs/__pycache__/aws_settings.cpython-310.pyc
diff --git a/configs/__pycache__/aws_settings.cpython-311.pyc b/configs/__pycache__/aws_settings.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83c4635748e9f4e89e483f1a174b9a80d61471f5
GIT binary patch
literal 726
[base85 binary payload omitted -- compiled .pyc artifact]
diff --git a/configs/__pycache__/ltx_settings.cpython-311.pyc b/configs/__pycache__/ltx_settings.cpython-311.pyc
new file mode 100644
GIT binary patch
literal 7975
[base85 binary payload omitted -- compiled .pyc artifact]
diff --git a/src/configs/__pycache__/mochi_settings.cpython-310.pyc b/configs/__pycache__/mochi_settings.cpython-310.pyc
similarity index 100%
rename from src/configs/__pycache__/mochi_settings.cpython-310.pyc
rename to configs/__pycache__/mochi_settings.cpython-310.pyc
diff --git a/src/configs/__pycache__/mochi_weights.cpython-310.pyc b/configs/__pycache__/mochi_weights.cpython-310.pyc
similarity index 100%
rename from src/configs/__pycache__/mochi_weights.cpython-310.pyc
rename to configs/__pycache__/mochi_weights.cpython-310.pyc
diff --git a/src/configs/aws_settings.py b/configs/aws_settings.py
similarity index 100%
rename from src/configs/aws_settings.py
rename to configs/aws_settings.py
diff --git a/configs/ltx_settings.py b/configs/ltx_settings.py
new file mode 100644
index 0000000..4df90d8
--- /dev/null
+++ b/configs/ltx_settings.py
@@ -0,0 +1,161 @@
+"""
+Configuration module for LTX video generation model with HuggingFace Hub integration.
+"""
+
+from typing import Optional, Union
+from pathlib import Path
+import os
+import torch
+from pydantic_settings import BaseSettings
+from pydantic import Field, field_validator, model_validator
+from huggingface_hub import snapshot_download
+from loguru import logger
+
+class LTXVideoSettings(BaseSettings):
+    """
+    Configuration settings for LTX video generation model.
+    """
+
+    # Model Settings
+    model_id: str = Field(
+        default="Lightricks/LTX-Video",
+        description="HuggingFace model ID"
+    )
+    ckpt_dir: Path = Field(
+        default_factory=lambda: Path(os.getenv('LTX_CKPT_DIR', 'checkpoints')),
+        description="Directory containing model checkpoints"
+    )
+    use_auth_token: Optional[str] = Field(
+        default=None,
+        description="HuggingFace auth token for private models"
+    )
+
+    # Input/Output Settings
+    input_video_path: Optional[Path] = Field(None, description="Path to input video file")
+    input_image_path: Optional[Path] = Field(None, description="Path to input image file")
+    output_path: Optional[Path] = Field(
+        default_factory=lambda: Path("outputs"),
+        description="Path to save output files"
+    )
+
+    # Generation Settings
+    seed: int = Field(171198, description="Random seed for generation")
+    num_inference_steps: int = Field(40, ge=1, le=100, description="Number of inference steps")
+    guidance_scale: float = Field(3.0, ge=1.0, le=20.0, description="Guidance scale")
+
+    # Video Parameters
+    height: int = Field(480, ge=256, le=720, description="Height of output video frames")
+    width: int = Field(704, ge=256, le=1280, description="Width of output video frames")
+    num_frames: int = Field(121, ge=1, le=257, description="Number of frames to generate")
+    frame_rate: int = Field(25, ge=1, le=60, description="Frame rate of output video")
+
+    # Model Settings
+    bfloat16: bool = Field(False, description="Use bfloat16 precision")
+    device: str = Field("cuda", description="Device to run inference on")
+
+    # Prompt Settings
+    prompt: Optional[str] = Field(None, description="Text prompt for generation")
+    negative_prompt: str = Field(
+        "worst quality, inconsistent motion, blurry, jittery, distorted",
+        description="Negative prompt for undesired features"
+    )
+
+    # Constants
+    MAX_HEIGHT: int = 720
+    MAX_WIDTH: int = 1280
+    MAX_NUM_FRAMES: int = 257
+
+    def download_model(self) -> Path:
+        """
+        Download model from HuggingFace Hub if not already present.
+
+        Returns:
+            Path: Path to the model checkpoint directory
+        """
+        try:
+            logger.info(f"Checking for model in {self.ckpt_dir}")
+
+            # Check if model files already exist
+            if self._verify_model_files():
+                logger.info("Model files already present")
+                return self.ckpt_dir
+
+            # Create checkpoint directory if it doesn't exist
+            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
+
+            # Download model from HuggingFace
+            logger.info(f"Downloading model {self.model_id} to {self.ckpt_dir}")
+            snapshot_download(
+                repo_id=self.model_id,
+                local_dir=self.ckpt_dir,
+                local_dir_use_symlinks=False,
+                repo_type='model',
+                token=self.use_auth_token
+            )
+
+            # Verify downloaded files
+            if not self._verify_model_files():
+                raise ValueError("Model download appears incomplete. Please check the files.")
+
+            logger.info("Model downloaded successfully")
+            return self.ckpt_dir
+
+        except Exception as e:
+            logger.error(f"Error downloading model: {e}")
+            raise
+
+    def _verify_model_files(self) -> bool:
+        """
+        Verify that all required model files are present.
+
+        Returns:
+            bool: True if all required files are present
+        """
+        required_dirs = ['unet', 'vae', 'scheduler']
+        required_files = {
+            'unet': ['config.json', 'unet_diffusion_pytorch_model.safetensors'],
+            'vae': ['config.json', 'vae_diffusion_pytorch_model.safetensors'],
+            'scheduler': ['scheduler_config.json']
+        }
+
+        try:
+            # Check for required directories
+            for dir_name in required_dirs:
+                dir_path = self.ckpt_dir / dir_name
+                if not dir_path.is_dir():
+                    return False
+
+                # Check for required files in each directory
+                for file_name in required_files[dir_name]:
+                    if not (dir_path / file_name).is_file():
+                        return False
+
+            return True
+
+        except Exception:
+            return False
+
+    @field_validator("ckpt_dir")
+    @classmethod
+    def validate_ckpt_dir(cls, v: Path) -> Path:
+        """Convert checkpoint directory to Path and create if needed."""
+        return Path(v)
+
+    # Other validators remain the same...
+
+    class Config:
+        """Pydantic configuration."""
+        env_prefix = "LTX_"
+        arbitrary_types_allowed = True
+        validate_assignment = True
+
+    def get_model_paths(self) -> tuple[Path, Path, Path]:
+        """Get paths to model components after ensuring model is downloaded."""
+        # Ensure model is downloaded
+        self.download_model()
+
+        unet_dir = self.ckpt_dir / "unet"
+        vae_dir = self.ckpt_dir / "vae"
+        scheduler_dir = self.ckpt_dir / "scheduler"
+
+        return unet_dir, vae_dir, scheduler_dir
\ No newline at end of file
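
A brief usage sketch for this settings class (illustrative, not part of the patch; it relies only on the fields and Config block above -- note the LTX_ env prefix):

    import os
    from configs.ltx_settings import LTXVideoSettings

    os.environ["LTX_NUM_INFERENCE_STEPS"] = "30"  # env override via the LTX_ prefix

    settings = LTXVideoSettings(prompt="aerial view of a rocky coastline")
    unet_dir, vae_dir, scheduler_dir = settings.get_model_paths()  # runs download_model() on first call

Because validate_assignment is enabled, the per-request attribute updates done in api/ltx_serve.py are re-validated against the field constraints.
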
diff --git a/src/configs/mochi_settings.py b/configs/mochi_settings.py
similarity index 100%
rename from src/configs/mochi_settings.py
rename to configs/mochi_settings.py
diff --git a/src/configs/mochi_weights.py b/configs/mochi_weights.py
similarity index 100%
rename from src/configs/mochi_weights.py
rename to configs/mochi_weights.py
diff --git a/pyproject.toml b/pyproject.toml
index 1f8b0f1..6584f5b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,14 @@ dependencies = [
     "boto3",
 ]
 
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["ltx", "scripts", "configs", "api"]
+
+
 [tool.ruff]
 line-length = 120
\ No newline at end of file
diff --git a/scripts/__pycache__/ltx_inference.cpython-311.pyc b/scripts/__pycache__/ltx_inference.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82a47c335aacda5fd6ea4fed5cfebd3e5dce5741
GIT binary patch
literal 18772
[base85 binary payload omitted -- compiled .pyc artifact]
diff --git a/src/scripts/__pycache__/mochi_diffusers.cpython-310.pyc b/scripts/__pycache__/mochi_diffusers.cpython-310.pyc
similarity index 100%
rename from src/scripts/__pycache__/mochi_diffusers.cpython-310.pyc
rename to scripts/__pycache__/mochi_diffusers.cpython-310.pyc
diff --git a/src/scripts/__pycache__/mp4_to_s3_json.cpython-310.pyc b/scripts/__pycache__/mp4_to_s3_json.cpython-310.pyc
similarity index 100%
rename from src/scripts/__pycache__/mp4_to_s3_json.cpython-310.pyc
rename to scripts/__pycache__/mp4_to_s3_json.cpython-310.pyc
diff --git a/scripts/__pycache__/mp4_to_s3_json.cpython-311.pyc b/scripts/__pycache__/mp4_to_s3_json.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c6b8f9ba6f530e07e617f1fd0dca2624df96e56
GIT binary patch
literal 1352
[base85 binary payload omitted -- compiled .pyc artifact]
zbYDoEVOOseILEzu3BL`5kD2gi=LD?jT+0d*XXu(o|CO!9jdcWrk_*po48crTRuZL$F;n^Cz)r5m*aP$<8zF(+C9@b%^0TVTt zXa>j5f>Wo#soMS2iCqg$)r09qFkPK*-a&`=YJoUkFn+<}viI+V+BrvAvZdBPfM_ro zVCRmkn}k0xa)a+6pSMW;Gaz9uzRl2OU$SIsb`d`!V{FUcu2p8}T#+QH31YS0)dYd6 l_?n=@wX=zd4J0Oa$reaklkiVLDU@<;FgcQP+zw7LbASnNgAI}ac2>qK?nni2~uS+1@MJ6)29GdYtH^X^#eunqB zFe3oxbK;ydBXPcsJR^I$&`b!pkW=QuGhq(#2#1$p#D~rIpqU7dP9jshg-mIg2Ms)a zndvaOG)l@ZS#>JOQ8{JjFIicTI>vsJPGi@#@>$1dn0qYLsJxy9;V#06O%pmwDPqZ#uy{!}|`%}api^*wK1Zi{XYXXNT3@N>6gW%C)gNU(kt6UVai zNs-Eiw+ys9bIs8+hJhVNzk+Y*mPxmiNtZPAXKoXuI|ZjFYo) z-X+X0RB;;1NxNZf^|4fHJe}50jbG3wPta|wDLtM#H-5fFLnEyD-1y1KPp1gr7r}%@ z2j?z1%wyWS!TrRx-TkMnyftSVm#zJdL9Bx7>~}`=Ie)446>d;P*TKTo_4&m5CEV9k z#bH2a49EgnjrH6auBlsAI041!)hU4c&yU$S!8^9dF3^#YZ_YekYbm_M3@t(=FOzk+ZVg;5DH zWcy_)Kt{8LaX?Xk4TNTc(Mp`j|A{a0rT{1ai-&>LO`#;D(JtgNCpOA%L_WMJx*g3r zj4PGIKcf45J0>}F0g!W3{6b{-@uh3rHFRCLfUa>#d5T7|00#TwEH>N+Je7|ZjVsut z!YO=%3Ma?U1EAED2ZYnuy=uY|A>us3(vYG5RD)%gFyhxGE^Wta7(~97B>JN05Ry_7McPOwzUQUY90pl z0%2QbXouiU-V_1C+)@zcT4w@T1-8jr5H;mQpuQ8nY?Mw(U5W3hIsE#*>M3mSsKkahGpr+LID#a z<6tVAR@QPIDmjIm)rKtE=?|hex^{Un0bh4g*9p$#yg;!Ydx-#5#k}=-5$i!n zncuTH0@vFP!-qL6tSq2f?~ZRm_d=^`?Do|C(Q4npzkmF<)0Mulr{~J*L`9vbsuQmS zJ^?We$YN(*Mmq);yK0G@-wfOvScz8>d#j1PE17EI0JO%max?|^(lMVu{>bNLB}IDS zb#0N3Ex8xz&Vg|E2gJG3W^s;RhH)MS(Bn{vPooFiR8sTk1Fcz&hT)vOK)|KZ_>1Vk^XS00{GTJ=k5-3{J~>qxny3yqty`Wwttw*d@m2Z*C!?yLY|88ZYKGv0IOXD{M| z&*OtDpFK&G^=#Q$=q4(v)GhbQd*crH=UiCT9ih8!Ho-M0q z8|iVAbUL)cm19G2SJdIEI$TzVH&3VRLf&7vYwlT>mM%ZdR zKpyVB^9f;D(IYVsoc{t@2rh7<6sekzMR;1pgR2>0(9QW9!#CHMv`lL!G znG;R|p}g&oC=;UWFT}73-sk_%c>f_}lo_&EL$~_P&OZO2#{bzQBGa=Yr0>^hCb3-R z`TCvm3xDB0A{U^I9nYPwfz$<#<7%jL!MkcGyx?6eg1X>cHMFzbez)`rx(}C8+g* None: + """Set random seeds for reproducibility""" + random.seed(self.config.seed) + np.random.seed(self.config.seed) + torch.manual_seed(self.config.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(self.config.seed) + + def _load_vae(self, vae_dir: Path) -> CausalVideoAutoencoder: + """Load and configure the VAE model""" + vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors" + vae_config_path = vae_dir / "config.json" + + with open(vae_config_path, "r") as f: + vae_config = json.load(f) + + vae = CausalVideoAutoencoder.from_config(vae_config) + vae_state_dict = safetensors.torch.load_file(vae_ckpt_path) + vae.load_state_dict(vae_state_dict) + + if torch.cuda.is_available(): + vae = vae.cuda() + return vae.to(torch.bfloat16) + + def _load_unet(self, unet_dir: Path) -> Transformer3DModel: + """Load and configure the UNet model""" + unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors" + unet_config_path = unet_dir / "config.json" + + transformer_config = Transformer3DModel.load_config(unet_config_path) + transformer = Transformer3DModel.from_config(transformer_config) + unet_state_dict = safetensors.torch.load_file(unet_ckpt_path) + transformer.load_state_dict(unet_state_dict, strict=True) + + if torch.cuda.is_available(): + transformer = transformer.cuda() + return transformer + + def _initialize_pipeline(self) -> LTXVideoPipeline: + """Initialize the complete LTX pipeline with all components""" + unet_dir, vae_dir, scheduler_dir = self.config.get_model_paths() + + # Load models + vae = self._load_vae(vae_dir) + unet = self._load_unet(unet_dir) + scheduler = self._load_scheduler(scheduler_dir) + patchifier = SymmetricPatchifier(patch_size=1) + + # Load text encoder and tokenizer + text_encoder = T5EncoderModel.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + subfolder="text_encoder" + ) + if 
+        if torch.cuda.is_available():
+            text_encoder = text_encoder.to("cuda")
+
+        tokenizer = T5Tokenizer.from_pretrained(
+            "PixArt-alpha/PixArt-XL-2-1024-MS",
+            subfolder="tokenizer"
+        )
+
+        if self.config.bfloat16 and unet.dtype != torch.bfloat16:
+            unet = unet.to(torch.bfloat16)
+
+        # Initialize pipeline with all components
+        pipeline = LTXVideoPipeline(
+            transformer=unet,
+            patchifier=patchifier,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+            vae=vae,
+        )
+
+        if torch.cuda.is_available():
+            pipeline = pipeline.to("cuda")
+
+        return pipeline
+
+    def _load_scheduler(self, scheduler_dir: Path):
+        """Load and configure the scheduler"""
+        from ltx_video.schedulers.rf import RectifiedFlowScheduler
+        scheduler_config_path = scheduler_dir / "scheduler_config.json"
+        scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
+        return RectifiedFlowScheduler.from_config(scheduler_config)
+
+    def load_input_image(self) -> Optional[torch.Tensor]:
+        """Load and preprocess input image if provided"""
+        if not self.config.input_image_path:
+            return None
+
+        image = Image.open(self.config.input_image_path).convert("RGB")
+        target_height, target_width = self.config.height, self.config.width
+
+        # Calculate aspect ratio and center-crop before resizing
+        input_width, input_height = image.size
+        aspect_ratio_target = target_width / target_height
+        aspect_ratio_frame = input_width / input_height
+
+        if aspect_ratio_frame > aspect_ratio_target:
+            new_width = int(input_height * aspect_ratio_target)
+            new_height = input_height
+            x_start = (input_width - new_width) // 2
+            y_start = 0
+        else:
+            new_width = input_width
+            new_height = int(input_width / aspect_ratio_target)
+            x_start = 0
+            y_start = (input_height - new_height) // 2
+
+        image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+        image = image.resize((target_width, target_height))
+
+        # Scale pixels to [-1, 1] and add batch and frame dimensions
+        frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
+        frame_tensor = (frame_tensor / 127.5) - 1.0
+        return frame_tensor.unsqueeze(0).unsqueeze(2)
+
+    def generate(self) -> None:
+        """Run the main generation pipeline"""
+        # Load input image if provided
+        media_items_prepad = self.load_input_image()
+
+        # Calculate dimensions
+        height_padded, width_padded = self.config.get_output_resolution()
+        num_frames_padded = self.config.get_padded_num_frames()
+
+        logger.info(f"Generating with dimensions: {height_padded}x{width_padded}x{num_frames_padded}")
+
+        # Calculate padding
+        padding = self._calculate_padding(
+            self.config.height,
+            self.config.width,
+            height_padded,
+            width_padded
+        )
+
+        # Pad input media if present
+        if media_items_prepad is not None:
+            media_items = F.pad(media_items_prepad, padding, mode="constant", value=-1)
+        else:
+            media_items = None
+
+        # Prepare generation parameters
+        generator = torch.Generator(
+            device="cuda" if torch.cuda.is_available() else "cpu"
+        ).manual_seed(self.config.seed)
+
+        # Run pipeline
+        images = self.pipeline(
+            prompt=self.config.prompt,
+            negative_prompt=self.config.negative_prompt,
+            num_inference_steps=self.config.num_inference_steps,
+            num_images_per_prompt=self.config.num_images_per_prompt,
+            guidance_scale=self.config.guidance_scale,
+            generator=generator,
+            output_type="pt",
+            height=height_padded,
+            width=width_padded,
+            num_frames=num_frames_padded,
+            frame_rate=self.config.frame_rate,
+            media_items=media_items,
+            is_video=True,
+            vae_per_channel_normalize=True,
+            conditioning_method=(
+                ConditioningMethod.FIRST_FRAME
+                if media_items is not None
+                else ConditioningMethod.UNCONDITIONAL
+            ),
+            mixed_precision=not self.config.bfloat16,
+        ).images
+
+        # Process and save outputs
+        self._save_outputs(images, padding, media_items_prepad)
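
generate() relies on get_output_resolution(), get_padded_num_frames(), and num_images_per_prompt from the settings object, none of which appear in the visible hunk of configs/ltx_settings.py. A sketch of the padding arithmetic these helpers are assumed to perform -- spatial dimensions rounded up to a multiple of 32, frame count to 8k + 1 -- purely for illustration:

    def pad_dims(height: int, width: int, num_frames: int) -> tuple[int, int, int]:
        def round32(x: int) -> int:
            return -(-x // 32) * 32  # ceil to a multiple of 32

        frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1  # ceil to 8k + 1
        return round32(height), round32(width), frames_padded

    assert pad_dims(480, 704, 121) == (480, 704, 121)  # the defaults are already aligned
    assert pad_dims(480, 704, 100) == (480, 704, 105)

_calculate_padding() below then centers the requested frame inside that padded canvas, and _save_outputs() crops the padding back off.
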
+
+    def _calculate_padding(
+        self,
+        source_height: int,
+        source_width: int,
+        target_height: int,
+        target_width: int
+    ) -> tuple[int, int, int, int]:
+        """Calculate padding values for input tensors"""
+        pad_height = target_height - source_height
+        pad_width = target_width - source_width
+
+        pad_top = pad_height // 2
+        pad_bottom = pad_height - pad_top
+        pad_left = pad_width // 2
+        pad_right = pad_width - pad_left
+
+        return (pad_left, pad_right, pad_top, pad_bottom)
+
+    def _save_outputs(
+        self,
+        images: torch.Tensor,
+        padding: tuple[int, int, int, int],
+        media_items_prepad: Optional[torch.Tensor]
+    ) -> None:
+        """Save generated outputs as videos and/or images"""
+        output_dir = (
+            Path(self.config.output_path)
+            if self.config.output_path
+            else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
+        )
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Crop padding
+        pad_left, pad_right, pad_top, pad_bottom = padding
+        pad_bottom = -pad_bottom if pad_bottom != 0 else images.shape[3]
+        pad_right = -pad_right if pad_right != 0 else images.shape[4]
+
+        images = images[
+            :, :,
+            :self.config.num_frames,
+            pad_top:pad_bottom,
+            pad_left:pad_right
+        ]
+
+        # Save each generated sequence
+        for i in range(images.shape[0]):
+            video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
+            video_np = (video_np * 255).astype(np.uint8)
+
+            # Save as image if single frame
+            if video_np.shape[0] == 1:
+                self._save_single_frame(video_np[0], i, output_dir)
+            else:
+                self._save_video(video_np, i, output_dir, media_items_prepad)
+
+        logger.info(f"Outputs saved to {output_dir}")
+
+    def _save_single_frame(self, frame: np.ndarray, index: int, output_dir: Path) -> None:
+        """Save a single frame as an image"""
+        output_filename = self._get_unique_filename(
+            f"image_output_{index}",
+            ".png",
+            output_dir
+        )
+        imageio.imwrite(output_filename, frame)
+
+    def _save_video(
+        self,
+        video_np: np.ndarray,
+        index: int,
+        output_dir: Path,
+        media_items_prepad: Optional[torch.Tensor]
+    ) -> None:
+        """Save video frames and optional condition image"""
+        # Save video
+        base_filename = f"img_to_vid_{index}" if self.config.input_image_path else f"text_to_vid_{index}"
+        output_filename = self._get_unique_filename(base_filename, ".mp4", output_dir)
+
+        with imageio.get_writer(output_filename, fps=self.config.frame_rate) as video:
+            for frame in video_np:
+                video.append_data(frame)
+
+        # Save condition image if provided
+        if media_items_prepad is not None:
+            reference_image = (
+                (media_items_prepad[0, :, 0].permute(1, 2, 0).cpu().data.numpy() + 1.0)
+                / 2.0 * 255
+            )
+            condition_filename = self._get_unique_filename(
+                base_filename,
+                ".png",
+                output_dir,
+                "_condition"
+            )
+            imageio.imwrite(condition_filename, reference_image.astype(np.uint8))
+
+    def _get_unique_filename(
+        self,
+        base: str,
+        ext: str,
+        output_dir: Path,
+        suffix: str = ""
+    ) -> Path:
+        """Generate a unique filename for outputs"""
+        prompt_str = self._convert_prompt_to_filename(self.config.prompt or "no_prompt")
+        base_filename = f"{base}_{prompt_str}_{self.config.seed}_{self.config.height}x{self.config.width}x{self.config.num_frames}"
+
+        for i in range(1000):
+            filename = output_dir / f"{base_filename}_{i}{suffix}{ext}"
+            if not filename.exists():
+                return filename
+
+        raise FileExistsError("Could not find a unique filename after 1000 attempts.")
+
+    @staticmethod
+    def _convert_prompt_to_filename(text: str, max_len: int = 30) -> str:
+        """Convert prompt text to a valid filename"""
+        clean_text = "".join(
+            char.lower() for char in text if char.isalpha() or char.isspace()
+        )
+        words = clean_text.split()
+
+        result = []
+        current_length = 0
+
+        for word in words:
+            new_length = current_length + len(word)
+            if new_length <= max_len:
+                result.append(word)
+                current_length += len(word)
+            else:
+                break
+
+        return "-".join(result)
+
+def main():
+    """Main entry point for inference"""
+    config = LTXVideoSettings()
+    inference = LTXInference(config)
+    inference.generate()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/scripts/mochi_diffusers.py b/scripts/mochi_diffusers.py
similarity index 100%
rename from src/scripts/mochi_diffusers.py
rename to scripts/mochi_diffusers.py
diff --git a/src/scripts/mp4_to_s3_json.py b/scripts/mp4_to_s3_json.py
similarity index 100%
rename from src/scripts/mp4_to_s3_json.py
rename to scripts/mp4_to_s3_json.py
diff --git a/src/scripts/s3_manager.py b/scripts/s3_manager.py
similarity index 100%
rename from src/scripts/s3_manager.py
rename to scripts/s3_manager.py
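
Putting the pieces together, a minimal end-to-end sketch of the two entry points this patch adds (the prompt and paths are illustrative):

    from configs.ltx_settings import LTXVideoSettings
    from scripts.ltx_inference import LTXInference

    settings = LTXVideoSettings(
        prompt="timelapse of clouds rolling over a mountain ridge",
        num_frames=121,
        output_path="outputs",
    )
    LTXInference(settings).generate()  # writes an .mp4 under outputs/

For the HTTP path, run python api/ltx_serve.py instead and POST to /api/v1/video/ltx as shown earlier; predict() uploads the result to S3 via scripts/mp4_to_s3_json.py rather than keeping it on local disk.
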