flux 2x24gb vram parallel strategy
apocas committed Oct 22, 2024
1 parent 92b282f commit 3d82b51
Showing 3 changed files with 149 additions and 0 deletions.
93 changes: 93 additions & 0 deletions app/llms/workers/children/flux1.py
@@ -0,0 +1,93 @@
import base64
import gc
import io

import torch
from diffusers import AutoencoderKL, FluxPipeline, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor

from app.config import RESTAI_DEFAULT_DEVICE

def flush():
    # Release cached CUDA memory between pipeline stages so the next stage
    # has (nearly) the full 2x24GB of VRAM available.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

def worker(prompt, sharedmem):

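    # Stage 1: load only the tokenizers and text encoders (no transformer,
    # no VAE), sharded across both 24GB GPUs by the balanced device map.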
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=None,
vae=None,
device_map="balanced",
max_memory={0: "24GB", 1: "24GB"},
torch_dtype=torch.bfloat16
)
with torch.no_grad():
print("Encoding prompts.")
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=512
)

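    # The embeddings are all we need from the text encoders; drop them and
    # free VRAM before loading the 12B-parameter transformer.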
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline

flush()

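    # Stage 2: load only the transformer, letting accelerate shard it
    # across both GPUs (device_map="auto").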
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
device_map="auto",
torch_dtype=torch.bfloat16
)

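    # Rebuild the pipeline around the sharded transformer; text encoders
    # and VAE stay unloaded.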
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
tokenizer_2=None,
vae=None,
transformer=transformer,
torch_dtype=torch.bfloat16
)

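    # Denoise from the precomputed embeddings and return packed latents
    # (output_type="latent") so VAE decoding can be deferred to stage 3.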
print("Running denoising.")
height, width = 768, 1360
latents = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=50,
guidance_scale=3.5,
height=height,
width=width,
output_type="latent",
).images

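    # Drop the transformer and free VRAM before loading the VAE.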
del pipeline.transformer
del pipeline

flush()

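    # Stage 3: decode the latents with the VAE on a single GPU.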
    vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
    ).to("cuda")
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

with torch.no_grad():
print("Running decoding.")
        # _unpack_latents is a private diffusers helper that reshapes Flux's
        # packed latent sequence back into image-shaped latents; the VAE's
        # scaling and shift are then undone before decoding.
        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="pil")

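        # Serialize the image as a base64-encoded JPEG for transport back
        # through the shared-memory dict.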
image_data = io.BytesIO()
image[0].save(image_data, format="JPEG")
image_base64 = base64.b64encode(image_data.getvalue()).decode('utf-8')

sharedmem["image"] = image_base64
54 changes: 54 additions & 0 deletions app/llms/workers/flux.py
@@ -0,0 +1,54 @@
from typing import Optional

from torch.multiprocessing import Manager, Process, set_start_method

from app.llms.workers.children.flux1 import worker

# CUDA requires child processes to use the "spawn" start method.
try:
    set_start_method('spawn')
except RuntimeError:
    pass

from ilock import ILock
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.tools import BaseTool
from langchain_community.chat_models import ChatOpenAI


class FluxImage(BaseTool):
name = "Flux Image Generator"
description = "use this tool when you need to generate an image using Flux."
return_direct = True

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        # With "boost" enabled, expand the user's description into a detailed
        # image prompt via an LLM; otherwise use the original question as-is.
        if run_manager.tags[0].boost:
            llm = ChatOpenAI(temperature=0.9, model_name="gpt-3.5-turbo")
            prompt = PromptTemplate(
                input_variables=["image_desc"],
                template="Generate a detailed prompt to generate an image based on the following description: {image_desc}",
            )
            chain = LLMChain(llm=llm, prompt=prompt)
            fprompt = chain.run(query)
        else:
            fprompt = run_manager.tags[0].question

manager = Manager()
sharedmem = manager.dict()

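        # One Flux generation at a time: the ILock serializes GPU access, and
        # running the worker in a spawned child process guarantees all CUDA
        # memory is released when it exits.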
with ILock('flux', timeout=180):
p = Process(target=worker, args=(fprompt, sharedmem))
p.start()
p.join()
p.kill()

if "image" not in sharedmem or not sharedmem["image"]:
raise Exception("An error occurred while processing the image. Please try again.")

return {"type": "flux", "image": sharedmem["image"], "prompt": fprompt}

async def _arun(self, query: str) -> str:
raise NotImplementedError("N/A")
2 changes: 2 additions & 0 deletions app/projects/vision.py
@@ -52,7 +52,9 @@ def question(self, project: Project, questionModel: QuestionModel, user: User, d
from app.llms.workers.stablediffusion import StableDiffusionImage
from app.llms.workers.describeimage import DescribeImage
from app.llms.workers.instantid import InstantID
from app.llms.workers.flux import FluxImage
tools.append(StableDiffusionImage())
tools.append(FluxImage())
tools.append(DescribeImage())
tools.append(InstantID())
