# NOTE(review): the original patch also added
# configs/__pycache__/allegro_settings.cpython-310.pyc (compiled bytecode).
# Build artifacts must not be committed — add `__pycache__/` to .gitignore
# and drop the .pyc from the change set.

# --- configs/allegro_settings.py ---
from pydantic_settings import BaseSettings


class AllegroSettings(BaseSettings):
    """
    Pydantic settings for the Allegro inference pipeline.

    Values may be overridden via environment variables prefixed with
    ``ALLEGRO_`` (e.g. ``ALLEGRO_DEVICE=cpu``); assignments are re-validated
    because ``validate_assignment`` is enabled.
    """

    model_name: str = "rhymes-ai/Allegro"  # Hugging Face repo id of the model
    device: str = "cuda"                   # torch device used for inference
    seed: int = 42                         # RNG seed for reproducible output
    guidance_scale: float = 7.5            # classifier-free guidance strength
    max_sequence_length: int = 512         # max text-encoder token length
    num_inference_steps: int = 100         # diffusion denoising steps
    fps: int = 15                          # frame rate of the exported video

    class Config:
        """Pydantic configuration enabling env-var support and re-validation."""

        env_prefix = "ALLEGRO_"  # Prefix for environment variables
        validate_assignment = True

    def __repr__(self) -> str:
        """
        Return a string representation of the settings for debugging purposes.

        :return: A string summarizing the settings.
        """
        return (f"AllegroSettings(model_name={self.model_name}, device={self.device}, seed={self.seed}, "
                f"guidance_scale={self.guidance_scale}, max_sequence_length={self.max_sequence_length}, "
                f"num_inference_steps={self.num_inference_steps}, fps={self.fps})")


# --- scripts/allegro_diffusers.py ---
import torch
from diffusers import AutoencoderKLAllegro, AllegroPipeline
from diffusers.utils import export_to_video
from loguru import logger

from configs.allegro_settings import AllegroSettings


class AllegroInference:
    """
    A class for managing the Allegro inference pipeline for generating videos
    based on textual prompts.

    This class encapsulates the initialization, configuration, and video
    generation processes for the Allegro model pipeline. It provides a
    streamlined way to handle prompts, model setup, and output file management.
    """

    def __init__(self, settings: AllegroSettings) -> None:
        """
        Initialize the AllegroInference class with the given settings.

        :param settings: An instance of AllegroSettings containing model,
            device, and generation parameters.
        """
        self.settings = settings
        self.pipe = None  # populated by _setup_pipeline()

        logger.info(f"Initializing {self.settings.model_name} inference pipeline")
        self._setup_pipeline()

    def _setup_pipeline(self) -> None:
        """
        Set up the Allegro model pipeline by loading the VAE and the pipeline
        with the configured model name and device.

        Loads the models, moves the pipeline to the configured device, and
        enables VAE tiling for memory-efficient inference.

        :raises Exception: If model loading or configuration fails.
        """
        try:
            # VAE is loaded in float32 while the rest of the pipeline runs in
            # bfloat16 — mixed precision as set up by the original author;
            # presumably for VAE numerical stability (TODO confirm).
            logger.info("Loading VAE model...")
            vae = AutoencoderKLAllegro.from_pretrained(
                self.settings.model_name,
                subfolder="vae",
                torch_dtype=torch.float32,
            )

            logger.info("Loading Allegro pipeline...")
            self.pipe = AllegroPipeline.from_pretrained(
                self.settings.model_name,
                vae=vae,
                torch_dtype=torch.bfloat16,
            )

            # Move pipeline to the specified device
            self.pipe.to(self.settings.device)

            # Tiling decodes the latent in chunks to reduce peak VRAM usage
            self.pipe.vae.enable_tiling()

            logger.info("Pipeline successfully initialized")
        except Exception:
            # logger.exception records the full traceback, then re-raise so
            # callers still see the failure.
            logger.exception("Error initializing pipeline")
            raise

    def generate_video(self, prompt: str, positive_prompt: str,
                       negative_prompt: str, output_path: str) -> None:
        """
        Generate a video based on the provided prompts and save it to the
        specified path.

        :param prompt: The main textual description of the video scene.
        :param positive_prompt: Template with a ``{}`` placeholder; additional
            positive prompts to enhance quality and style.
        :param negative_prompt: Prompts to avoid undesirable features in the
            generated video.
        :param output_path: File path to save the generated video.
        :raises Exception: If there is an error during video generation or export.
        """
        try:
            logger.info("Preparing prompts...")
            # Keep the original argument intact; build the final prompt in a
            # separate local instead of rebinding the parameter.
            full_prompt = positive_prompt.format(prompt.lower().strip())

            logger.info("Generating video...")
            generator = torch.Generator(device=self.settings.device).manual_seed(
                self.settings.seed
            )
            video_frames = self.pipe(
                full_prompt,
                negative_prompt=negative_prompt,
                guidance_scale=self.settings.guidance_scale,
                max_sequence_length=self.settings.max_sequence_length,
                num_inference_steps=self.settings.num_inference_steps,
                generator=generator,
            ).frames[0]

            logger.info(f"Exporting video to {output_path}...")
            export_to_video(video_frames, output_path, fps=self.settings.fps)

            logger.info("Video generation completed successfully")
        except Exception:
            logger.exception("Error during video generation")
            raise


def main() -> None:
    """Example usage: generate a demo video with default settings."""
    settings = AllegroSettings()
    inference = AllegroInference(settings)

    prompt = "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats."
    positive_prompt = """
    (masterpiece), (best quality), (ultra-detailed), (unwatermarked),
    {}
    emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo,
    sharp focus, high budget, cinemascope, moody, epic, gorgeous
    """

    negative_prompt = """
    nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality,
    low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.
    """

    output_path = "output.mp4"
    inference.generate_video(prompt, positive_prompt, negative_prompt, output_path)


if __name__ == "__main__":
    main()