From 3d82b51839e78a1f6039d0e0cc6c0b85c5f65f66 Mon Sep 17 00:00:00 2001
From: Pedro Dias
Date: Tue, 22 Oct 2024 21:44:35 +0000
Subject: [PATCH] flux: 2x24GB VRAM parallel strategy

---
 app/llms/workers/children/flux1.py | 93 ++++++++++++++++++++++++++++++
 app/llms/workers/flux.py           | 54 +++++++++++++++++
 app/projects/vision.py             |  2 +
 3 files changed, 149 insertions(+)
 create mode 100644 app/llms/workers/children/flux1.py
 create mode 100644 app/llms/workers/flux.py

diff --git a/app/llms/workers/children/flux1.py b/app/llms/workers/children/flux1.py
new file mode 100644
index 0000000..6de4623
--- /dev/null
+++ b/app/llms/workers/children/flux1.py
@@ -0,0 +1,93 @@
+import base64
+import io
+import torch
+from diffusers import FluxPipeline
+import gc
+from diffusers import FluxTransformer2DModel
+from diffusers import AutoencoderKL
+from diffusers.image_processor import VaeImageProcessor
+
+from app.config import RESTAI_DEFAULT_DEVICE
+
+def flush():
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_max_memory_allocated()
+    torch.cuda.reset_peak_memory_stats()
+
+def worker(prompt, sharedmem):
+
+    pipeline = FluxPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev",
+        transformer=None,
+        vae=None,
+        device_map="balanced",
+        max_memory={0: "24GB", 1: "24GB"},
+        torch_dtype=torch.bfloat16
+    )
+    with torch.no_grad():
+        print("Encoding prompts.")
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            prompt=prompt, prompt_2=None, max_sequence_length=512
+        )
+
+    del pipeline.text_encoder
+    del pipeline.text_encoder_2
+    del pipeline.tokenizer
+    del pipeline.tokenizer_2
+    del pipeline
+
+    flush()
+
+    transformer = FluxTransformer2DModel.from_pretrained(
+        "black-forest-labs/FLUX.1-dev",
+        subfolder="transformer",
+        device_map="auto",
+        torch_dtype=torch.bfloat16
+    )
+
+    pipeline = FluxPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev",
+        text_encoder=None,
+        text_encoder_2=None,
+        tokenizer=None,
+        tokenizer_2=None,
+        vae=None,
+        transformer=transformer,
+        torch_dtype=torch.bfloat16
+    )
+
+    print("Running denoising.")
+    height, width = 768, 1360
+    latents = pipeline(
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        num_inference_steps=50,
+        guidance_scale=3.5,
+        height=height,
+        width=width,
+        output_type="latent",
+    ).images
+
+    del pipeline.transformer
+    del pipeline
+
+    flush()
+
+    vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
+
+    with torch.no_grad():
+        print("Running decoding.")
+        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
+        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+
+        image = vae.decode(latents, return_dict=False)[0]
+        image = image_processor.postprocess(image, output_type="pil")
+
+    image_data = io.BytesIO()
+    image[0].save(image_data, format="JPEG")
+    image_base64 = base64.b64encode(image_data.getvalue()).decode('utf-8')
+
+    sharedmem["image"] = image_base64
diff --git a/app/llms/workers/flux.py b/app/llms/workers/flux.py
new file mode 100644
index 0000000..289342b
--- /dev/null
+++ b/app/llms/workers/flux.py
@@ -0,0 +1,54 @@
+from torch.multiprocessing import Process, set_start_method, Manager
+
+from app.llms.workers.children.flux1 import worker
+
+try:
+    set_start_method('spawn')
+except RuntimeError:
+    pass
+from langchain.tools import BaseTool
+from langchain.chains import LLMChain
+from langchain_community.chat_models import ChatOpenAI
+from langchain.prompts import PromptTemplate
+
+from typing import Optional
+from langchain.callbacks.manager import (
+    CallbackManagerForToolRun,
+)
+from ilock import ILock, ILockException
+
+
+class FluxImage(BaseTool):
+    name = "Flux Image Generator"
+    description = "use this tool when you need to generate an image using Flux."
+    return_direct = True
+
+    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
+        if run_manager.tags[0].boost == True:
+            llm = ChatOpenAI(temperature=0.9, model_name="gpt-3.5-turbo")
+            prompt = PromptTemplate(
+                input_variables=["image_desc"],
+                template="Generate a detailed prompt to generate an image based on the following description: {image_desc}",
+            )
+            chain = LLMChain(llm=llm, prompt=prompt)
+
+            fprompt = chain.run(query)
+        else:
+            fprompt = run_manager.tags[0].question
+
+        manager = Manager()
+        sharedmem = manager.dict()
+
+        with ILock('flux', timeout=180):
+            p = Process(target=worker, args=(fprompt, sharedmem))
+            p.start()
+            p.join()
+            p.kill()
+
+        if "image" not in sharedmem or not sharedmem["image"]:
+            raise Exception("An error occurred while processing the image. Please try again.")
+
+        return {"type": "flux", "image": sharedmem["image"], "prompt": fprompt}
+
+    async def _arun(self, query: str) -> str:
+        raise NotImplementedError("N/A")
\ No newline at end of file
diff --git a/app/projects/vision.py b/app/projects/vision.py
index d171c30..87dd5a2 100644
--- a/app/projects/vision.py
+++ b/app/projects/vision.py
@@ -52,7 +52,9 @@ def question(self, project: Project, questionModel: QuestionModel, user: User, d
         from app.llms.workers.stablediffusion import StableDiffusionImage
         from app.llms.workers.describeimage import DescribeImage
         from app.llms.workers.instantid import InstantID
+        from app.llms.workers.flux import FluxImage
 
         tools.append(StableDiffusionImage())
+        tools.append(FluxImage())
         tools.append(DescribeImage())
         tools.append(InstantID())
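
Note (not part of the patch): a minimal sketch of how the new worker could be exercised standalone, outside the FluxImage LangChain tool, assuming two CUDA devices with roughly 24GB each are visible and the FLUX.1-dev weights are already downloaded; the prompt text and the flux_test.jpg output path are arbitrary placeholders.

    # Standalone smoke test for app/llms/workers/children/flux1.py (sketch only).
    import base64

    from torch.multiprocessing import Manager, Process, set_start_method

    from app.llms.workers.children.flux1 import worker

    if __name__ == "__main__":
        # Same start method the FluxImage tool configures in app/llms/workers/flux.py.
        set_start_method("spawn", force=True)

        manager = Manager()
        sharedmem = manager.dict()

        # The worker runs its three phases (prompt encoding balanced across both GPUs,
        # denoising with the sharded transformer, VAE decoding) in a child process so
        # all VRAM is released when the process exits.
        p = Process(target=worker, args=("a lighthouse at sunset, photorealistic", sharedmem))
        p.start()
        p.join()

        # The worker publishes the JPEG as a base64 string via the shared dict.
        with open("flux_test.jpg", "wb") as f:  # hypothetical output path
            f.write(base64.b64decode(sharedmem["image"]))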