From e11293ac40798b91bde61b7055e619674df20d17 Mon Sep 17 00:00:00 2001 From: Jonas Loos <33965649+JonasLoos@users.noreply.github.com> Date: Wed, 20 Mar 2024 02:39:36 +0100 Subject: [PATCH] show disabled representation positions --- generate-reprs.ipynb | 21 ++++++++++----------- index.html | 13 ++++++++++--- sdwrapper.py | 23 +++++++++++++++-------- 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/generate-reprs.ipynb b/generate-reprs.ipynb index 6b14cb3..4806f11 100644 --- a/generate-reprs.ipynb +++ b/generate-reprs.ipynb @@ -37,7 +37,7 @@ "models: list[dict[str,Any]] = [\n", " dict(\n", " short='SDXL-Turbo',\n", - " extract_positions = SD.available_extract_positions,\n", + " extract_positions = None, # use all\n", " ),\n", " dict(\n", " short='SDXL-Lightning',\n", @@ -45,26 +45,23 @@ " ),\n", " dict(\n", " short='SD-Turbo',\n", - " extract_positions = SD.available_extract_positions,\n", + " extract_positions = None, # use all\n", " ),\n", " dict(\n", " short='SD-1.5',\n", - " name='runwayml/stable-diffusion-v1-5',\n", - " extract_positions = ['down_blocks[2]', 'mid_block', 'up_blocks[0]', 'conv_out'],\n", + " extract_positions = ['down_blocks[2]', 'down_blocks[3]', 'mid_block', 'up_blocks[0]', 'conv_out'],\n", " ),\n", "]\n", "\n", "prompts = {\n", " \"Cat\": \"A photo of a cat.\",\n", " \"Dog\": \"A photograph of a husky, dog, looking friendly and cute.\",\n", - " # \"Polarbear\": \"A photo of a polar bear.\",\n", " \"Woman\": \"A photo of a beautiful, slightly smiling woman in the city.\",\n", " \"OldMan\": \"A portrait of an old man with a long beard and a hat.\",\n", " \"ConstructionWorker\": \"A photo of a hard working construction worker.\",\n", " \"FuturisticCityscape\": \"A futuristic cityscape at sunset, with flying cars and towering skyscrapers, in the style of cyberpunk.\",\n", " \"MountainLandscape\": \"A serene mountain landscape with a crystal-clear lake in the foreground, reflecting the snow-capped peaks under a bright blue 
sky.\",\n", " \"SpaceAstronaut\": \"A high-res photo of an astronaut floating in the vastness of space, with a colorful nebula and distant galaxies in the background.\",\n", - " # \"MajesticLion\": \"A close-up portrait of a majestic lion, with detailed fur and piercing eyes, set against the backdrop of the African savannah at dusk.\",\n", " \"MagicalForest\": \"A magical forest filled with glowing plants, mythical creatures, and a pathway leading to an enchanted castle.\",\n", " \"JapaneseGarden\": \"A traditional Japanese garden in spring, complete with cherry blossoms, a koi pond, and a wooden bridge.\",\n", "}\n", @@ -94,11 +91,12 @@ "# run the models\n", "for model_dict in models:\n", " name_short = model_dict['short']\n", - " sd = SD(name_short)\n", - " sd.pipeline.set_progress_bar_config(disable=True)\n", + " sd = SD(name_short, disable_progress_bar=True)\n", + " extract_positions = model_dict['extract_positions'] or sd.available_extract_positions\n", + " del model_dict['extract_positions'] # not needed anymore\n", "\n", " def get_reprs(prompt):\n", - " result = sd(prompt, seed=seed, extract_positions=model_dict['extract_positions'])\n", + " result = sd(prompt, seed=seed, extract_positions=sd.available_extract_positions)\n", " representations = {}\n", " for pos, reprs in result.representations.items():\n", " representations[pos] = []\n", @@ -119,6 +117,7 @@ " model_dict['representations'][pos] = {\n", " 'channels': reprs[0].shape[-1],\n", " 'spatial': reprs[0].shape[-2],\n", + " 'available': pos in extract_positions,\n", " }\n", "\n", " # go through prompts\n", @@ -132,8 +131,8 @@ " representations, images = get_reprs(prompt)\n", "\n", " # save representations\n", - " for pos, reprs in representations.items():\n", - " for j, repr in enumerate(reprs, 0):\n", + " for pos in extract_positions:\n", + " for j, repr in enumerate(representations[pos], 0):\n", " with open(save_path / f'repr-{pos}-{j}.bin', 'wb') as f:\n", " f.write(np.array(repr, 
dtype=np.float16).tobytes())\n", "\n", diff --git a/index.html b/index.html index 8260f1e..5f42db3 100644 --- a/index.html +++ b/index.html @@ -154,7 +154,7 @@

H-Space similarity explorer

// send a task to the webworker and return a promise function callWasm(task, data) { const id = worker.job_counter++; - const job = {id, task} + const job = {id, task}; const promise = new Promise((resolve, reject) => { job.resolve = resolve; job.reject = reject; @@ -477,11 +477,18 @@

H-Space similarity explorer

function updatePositionSelector() { - if (current_model.representations[current_position] === undefined) current_position = 'mid_block'; // fallback to mid_block, if current position is not available + if (current_model.representations[current_position] === undefined || !current_model.representations[current_position].available) + current_position = 'mid_block'; // fallback to mid_block, if current position is not available document.getElementById('position-to-use').innerHTML = ''; // clear old options Object.keys(current_model.representations).forEach(position => { // setup available positions (where the representations are extracted from) - createElem('option', { value: position, textContent: position == 'mid_block' ? 'mid_block (h-space)' : position, selected: position == current_position }, document.getElementById('position-to-use')); + createElem('option', { + value: position, + textContent: position == 'mid_block' ? 'mid_block (h-space)' : position, + disabled: !current_model.representations[position].available, + title: !current_model.representations[position].available ? 
'Not available due to large representation size' : undefined, + selected: position == current_position + }, document.getElementById('position-to-use')); }); } diff --git a/sdwrapper.py b/sdwrapper.py index 02a4bc0..72e3e81 100644 --- a/sdwrapper.py +++ b/sdwrapper.py @@ -4,6 +4,7 @@ import torch from typing import Optional, Callable, Any, Literal from PIL.Image import Image as PILImage +from dataclasses import dataclass @@ -38,18 +39,16 @@ def load_sdxl_lightning(steps: int, device: str): return pipe - +@dataclass class SDResult: prompt : str seed : int - representations : dict[str, list[torch.Tensor]] + representations : dict[str, list[torch.Tensor|tuple]] images : list[PILImage] result_latent : torch.Tensor result_tensor : torch.Tensor result_image : PILImage - def __init__(self, **kwargs): self.__dict__.update(kwargs) - def __repr__(self): return f'' - + def __repr__(self): return f'SDResult(prompt="{self.prompt}",seed={self.seed},...)' class SD: ''' @@ -100,9 +99,10 @@ def __init__( self, model_name: Literal['SD1.5', 'SD2.1', 'SD-Turbo', 'SDXL-Turbo'] | str = 'SD1.5', device: str = 'auto', + disable_progress_bar: bool = False, ): self.model_name = model_name - self.config = self.known_models.get(model_name, {'name':model_name}) + self.config = self.known_models.get(model_name, {'name': model_name}) self.device = device if device != 'auto' else 'cuda' if torch.cuda.is_available() else 'cpu' # setup pipeline @@ -112,12 +112,19 @@ def __init__( self.pipeline = AutoPipelineForText2Image.from_pretrained(self.config['name'], torch_dtype=torch.float16).to(self.device) self.vae = self.pipeline.vae + if disable_progress_bar and hasattr(self.pipeline, 'set_progress_bar_config'): + self.pipeline.set_progress_bar_config(disable=True) + # upcast vae if necessary (SDXL models require float32) if hasattr(self.pipeline, 'upcast_vae') and self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.pipeline.upcast_vae() - # check h-space dim - # TODO + # setup 
available extract positions + self.available_extract_positions = [] + for name, obj in self.pipeline.unet.named_children(): + if any(x in name for x in ['time_', '_norm', '_act', '_embedding']): continue + self.available_extract_positions += [f'{name}[{i}]' for i in range(len(obj))] if isinstance(obj, torch.nn.ModuleList) else [name] + self.available_extract_positions.sort(key=lambda x: ['conv_in', 'down_blocks', 'mid_block', 'up_blocks', 'conv_out'].index(x.split('[')[0])) @torch.no_grad() def vae_decode(self, latents):