Merge pull request #7 from VikramxD/minimochi-dev
Refactor gitignore and gitmodules
VikramxD authored Nov 26, 2024
2 parents b902c92 + 2d4c4a3 commit e48f5f8
Showing 11 changed files with 82 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1 +1 @@
-src/scripts/weights
+api/checkpoints
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,3 +1,3 @@
 [submodule "LTX-Video"]
-    path = LTX-Video
+    path = ltx
     url = https://github.com/Lightricks/LTX-Video.git
18 changes: 18 additions & 0 deletions api/client.py
@@ -0,0 +1,18 @@

+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import requests
+
+response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0})
+print(f"Status: {response.status_code}\nResponse:\n {response.text}")
12 changes: 12 additions & 0 deletions api/logs/ltx_api.log
@@ -1,2 +1,14 @@
 2024-11-23 19:32:34.895 | ERROR | __main__:main:352 - Server failed to start: LitServer.__init__() got an unexpected keyword argument 'generate_client_file'
 2024-11-23 19:33:20.102 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:03:46.386 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:27:38.795 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:33:47.301 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:37:58.440 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:40:43.972 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 10:47:33.232 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:09:31.128 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:20:44.187 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:26:48.145 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:27:36.160 | INFO | __main__:main:347 - Starting LTX video generation server on port 8000
+2024-11-26 11:31:41.717 | INFO | __main__:main:343 - Starting LTX video generation server on port 8000
+2024-11-26 11:34:03.736 | INFO | __main__:main:343 - Starting LTX video generation server on port 8000
46 changes: 21 additions & 25 deletions api/ltx_serve.py
@@ -103,19 +103,12 @@ def setup(self, device: str) -> None:
         try:
             logger.info(f"Initializing LTX video generation on device: {device}")

-            # Initialize settings
-            self.settings = LTXVideoSettings(
-                device=device,
-                ckpt_dir=os.environ.get("LTX_CKPT_DIR", "checkpoints"),
-            )
+            # Initialize settings with device
+            self.settings = LTXVideoSettings(device=device)

             # Initialize inference engine
             self.engine = LTXInference(self.settings)

-            # Create output directory
-            self.output_dir = Path("outputs")
-            self.output_dir.mkdir(parents=True, exist_ok=True)
-
             logger.info("LTX setup completed successfully")

         except Exception as e:
@@ -193,21 +186,21 @@ def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                 # Validate request
                 generation_request = VideoGenerationRequest(**request)

-                # Update settings with request parameters
-                self.settings.prompt = generation_request.prompt
-                self.settings.negative_prompt = generation_request.negative_prompt
-                self.settings.num_inference_steps = generation_request.num_inference_steps
-                self.settings.guidance_scale = generation_request.guidance_scale
-                self.settings.height = generation_request.height
-                self.settings.width = generation_request.width
-                self.settings.num_frames = generation_request.num_frames
-                self.settings.frame_rate = generation_request.frame_rate
-                self.settings.seed = generation_request.seed
-
                 # Create temporary directory for output
                 with tempfile.TemporaryDirectory() as temp_dir:
                     temp_video_path = Path(temp_dir) / f"ltx_{int(time.time())}.mp4"
+
+                    # Update settings with request parameters
+                    self.settings.prompt = generation_request.prompt
+                    self.settings.negative_prompt = generation_request.negative_prompt
+                    self.settings.num_inference_steps = generation_request.num_inference_steps
+                    self.settings.guidance_scale = generation_request.guidance_scale
+                    self.settings.height = generation_request.height
+                    self.settings.width = generation_request.width
+                    self.settings.num_frames = generation_request.num_frames
+                    self.settings.frame_rate = generation_request.frame_rate
+                    self.settings.seed = generation_request.seed
-                    self.settings.output_path = str(temp_video_path)
+                    self.settings.output_path = temp_video_path

                     # Generate video
                     logger.info(f"Starting generation for prompt: {generation_request.prompt}")
@@ -218,7 +211,10 @@ def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                     self.log("inference_time", generation_time)

                     # Get memory statistics
-                    memory_stats = self.engine.get_memory_stats()
+                    memory_stats = {
+                        "gpu_allocated": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
+                        "gpu_reserved": torch.cuda.memory_reserved() if torch.cuda.is_available() else 0
+                    }

                     # Upload to S3
                     s3_response = mp4_to_s3_json(
@@ -247,9 +243,9 @@ def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                     })

             finally:
-                # Cleanup
-                if hasattr(self.engine, 'clear_memory'):
-                    self.engine.clear_memory()
+                # Clear CUDA cache after each generation
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()

         except Exception as e:
             logger.error(f"Error in predict method: {e}")
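The refactor inlines GPU memory reporting in predict() instead of calling self.engine.get_memory_stats(). A standalone sketch of the same two queries (the helper name here is illustrative): torch.cuda.memory_allocated() counts bytes held by live tensors, while torch.cuda.memory_reserved() counts bytes held by the caching allocator, so reserved is always at least allocated.

    import torch

    def gpu_memory_stats() -> dict:
        # The same queries the diff inlines into predict(): tensor-allocated
        # bytes and bytes reserved by the CUDA caching allocator.
        if not torch.cuda.is_available():
            return {"gpu_allocated": 0, "gpu_reserved": 0}
        return {
            "gpu_allocated": torch.cuda.memory_allocated(),
            "gpu_reserved": torch.cuda.memory_reserved(),
        }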
Binary file modified configs/__pycache__/ltx_settings.cpython-311.pyc
Binary file not shown.
29 changes: 28 additions & 1 deletion configs/ltx_settings.py
@@ -48,6 +48,12 @@ class LTXVideoSettings(BaseSettings):
     width: int = Field(704, ge=256, le=1280, description="Width of output video frames")
     num_frames: int = Field(121, ge=1, le=257, description="Number of frames to generate")
     frame_rate: int = Field(25, ge=1, le=60, description="Frame rate of output video")
+    num_images_per_prompt: int = Field(
+        1,
+        ge=1,
+        le=4,
+        description="Number of videos to generate per prompt"
+    )

     # Model Settings
     bfloat16: bool = Field(False, description="Use bfloat16 precision")
@@ -158,4 +164,25 @@ def get_model_paths(self) -> tuple[Path, Path, Path]:
vae_dir = self.ckpt_dir / "vae"
scheduler_dir = self.ckpt_dir / "scheduler"

return unet_dir, vae_dir, scheduler_dir
return unet_dir, vae_dir, scheduler_dir

def get_output_resolution(self) -> tuple[int, int]:
"""Get the output resolution as a tuple of (height, width)."""
return (self.height, self.width)

def get_padded_num_frames(self) -> int:
"""
Calculate the padded number of frames.
Ensures the number of frames is compatible with model requirements.
"""
# Common video models often require frame counts to be multiples of 8
FRAME_PADDING = 8

# Calculate padding needed to reach next multiple of FRAME_PADDING
remainder = self.num_frames % FRAME_PADDING
if remainder == 0:
return self.num_frames

padding_needed = FRAME_PADDING - remainder
return self.num_frames + padding_needed
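The new get_padded_num_frames() rounds num_frames up to the next multiple of 8. A worked example of the same arithmetic as a standalone sketch: with the default num_frames of 121, 121 % 8 == 1, so the padded count is 121 + 7 = 128.

    FRAME_PADDING = 8

    def padded_num_frames(num_frames: int) -> int:
        # Mirrors LTXVideoSettings.get_padded_num_frames: round up to a multiple of 8.
        remainder = num_frames % FRAME_PADDING
        return num_frames if remainder == 0 else num_frames + (FRAME_PADDING - remainder)

    assert padded_num_frames(121) == 128  # default num_frames pads up to 128
    assert padded_num_frames(128) == 128  # already aligned, returned unchanged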

Binary file modified scripts/__pycache__/ltx_inference.cpython-311.pyc
Binary file not shown.
Binary file modified scripts/__pycache__/mp4_to_s3_json.cpython-311.pyc
Binary file not shown.
Binary file modified scripts/__pycache__/s3_manager.cpython-311.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion scripts/ltx_inference.py
@@ -128,7 +128,7 @@ def _initialize_pipeline(self) -> LTXVideoPipeline:

     def _load_scheduler(self, scheduler_dir: Path):
         """Load and configure the scheduler"""
-        from ltx_video.schedulers.rf import RectifiedFlowScheduler
+        from ltx.ltx_video.schedulers.rf import RectifiedFlowScheduler
         scheduler_config_path = scheduler_dir / "scheduler_config.json"
         scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
         return RectifiedFlowScheduler.from_config(scheduler_config)
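This import rename tracks the submodule path change in .gitmodules above (LTX-Video to ltx). A quick sanity check, assuming the repository root is on sys.path and the relocated checkout still exposes its ltx_video package:

    import importlib.util

    # Hypothetical check that the renamed submodule path resolves; assumes the
    # repo root is importable and ltx/ now holds the LTX-Video checkout.
    try:
        spec = importlib.util.find_spec("ltx.ltx_video.schedulers.rf")
    except ModuleNotFoundError:
        spec = None
    print("scheduler module found:", spec is not None)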
