diff --git a/run.py b/run.py
index d5bf52d..26208fb 100644
--- a/run.py
+++ b/run.py
@@ -79,15 +79,14 @@ def generate_images(prompt, height, width, negative_prompt, guidance_scale, num_
     # Sanitize user input prompt before using it, with a timeout of 5 seconds
     cleaned_prompt = clean_prompt_with_timeout(prompt, timeout=5)
     print("Processed prompt:", cleaned_prompt)
-
-    # Load, use, and discard the prior model
-    prior = load_model("prior")
 
     with torch.cuda.amp.autocast(dtype=dtype):
         seed = torch.seed() if seed == -1 else seed  # Get the initial seed
         torch.manual_seed(seed)  # Apply the seed for generation
         generator = torch.Generator(device).manual_seed(seed)  # Preserve for reproducibility
 
+        # Load, use, and discard the prior model
+        prior = load_model("prior")
         prior.enable_model_cpu_offload()
         prior_output = prior(
             prompt=cleaned_prompt,
@@ -99,6 +98,8 @@ def generate_images(prompt, height, width, negative_prompt, guidance_scale, num_
             num_images_per_prompt=int(num_images_per_prompt),
             generator=generator,
         )
+        del prior
+        torch.cuda.empty_cache()  # Release GPU memory
 
         # Load, use, and discard the decoder model
         decoder = load_model("decoder")
@@ -112,6 +113,8 @@ def generate_images(prompt, height, width, negative_prompt, guidance_scale, num_
             output_type="pil",
             generator=generator,
         ).images
+        del decoder
+        torch.cuda.empty_cache()  # Release GPU memory
 
         metadata_embedded = {
             "parameters": "Stable Cascade",
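The load / use / del / torch.cuda.empty_cache() sequence now repeats once per stage. As a minimal standalone sketch of that pattern (not part of run.py: run_stage is a hypothetical helper, and load_model is assumed to return a callable pipeline as it does in run.py), the teardown could be factored out so it always runs, even if the stage raises:

import gc

import torch

def run_stage(load_fn, *args, **kwargs):
    # Load a model, run it once, then aggressively release its GPU memory.
    # This mirrors the `del` + `torch.cuda.empty_cache()` sequence added in
    # the diff above; gc.collect() is extra insurance that the Python object
    # is actually collected before the CUDA caching allocator is asked to
    # return its cached blocks to the driver.
    model = load_fn()
    try:
        return model(*args, **kwargs)
    finally:
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

With a helper like that, the prior stage would collapse to something like prior_output = run_stage(lambda: load_model("prior"), prompt=cleaned_prompt, ...), keeping the free next to the use site; the explicit inline version in the diff trades that brevity for making the memory lifecycle visible in generate_images itself.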