Skip to content

Commit

Permalink
show disabled representation positions
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasLoos committed Mar 20, 2024
1 parent 2a174f3 commit e11293a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 22 deletions.
21 changes: 10 additions & 11 deletions generate-reprs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,34 +37,31 @@
"models: list[dict[str,Any]] = [\n",
" dict(\n",
" short='SDXL-Turbo',\n",
" extract_positions = SD.available_extract_positions,\n",
" extract_positions = None, # use all\n",
" ),\n",
" dict(\n",
" short='SDXL-Lightning',\n",
" extract_positions = ['down_blocks[0]', 'down_blocks[1]', 'down_blocks[2]', 'mid_block', 'conv_out'],\n",
" ),\n",
" dict(\n",
" short='SD-Turbo',\n",
" extract_positions = SD.available_extract_positions,\n",
" extract_positions = None, # use all\n",
" ),\n",
" dict(\n",
" short='SD-1.5',\n",
" name='runwayml/stable-diffusion-v1-5',\n",
" extract_positions = ['down_blocks[2]', 'mid_block', 'up_blocks[0]', 'conv_out'],\n",
" extract_positions = ['down_blocks[2]', 'down_blocks[3]', 'mid_block', 'up_blocks[0]', 'conv_out'],\n",
" ),\n",
"]\n",
"\n",
"prompts = {\n",
" \"Cat\": \"A photo of a cat.\",\n",
" \"Dog\": \"A photograph of a husky, dog, looking friendly and cute.\",\n",
" # \"Polarbear\": \"A photo of a polar bear.\",\n",
" \"Woman\": \"A photo of a beautiful, slightly smiling woman in the city.\",\n",
" \"OldMan\": \"A portrait of an old man with a long beard and a hat.\",\n",
" \"ConstructionWorker\": \"A photo of a hard working construction worker.\",\n",
" \"FuturisticCityscape\": \"A futuristic cityscape at sunset, with flying cars and towering skyscrapers, in the style of cyberpunk.\",\n",
" \"MountainLandscape\": \"A serene mountain landscape with a crystal-clear lake in the foreground, reflecting the snow-capped peaks under a bright blue sky.\",\n",
" \"SpaceAstronaut\": \"A high-res photo of an astronaut floating in the vastness of space, with a colorful nebula and distant galaxies in the background.\",\n",
" # \"MajesticLion\": \"A close-up portrait of a majestic lion, with detailed fur and piercing eyes, set against the backdrop of the African savannah at dusk.\",\n",
" \"MagicalForest\": \"A magical forest filled with glowing plants, mythical creatures, and a pathway leading to an enchanted castle.\",\n",
" \"JapaneseGarden\": \"A traditional Japanese garden in spring, complete with cherry blossoms, a koi pond, and a wooden bridge.\",\n",
"}\n",
Expand Down Expand Up @@ -94,11 +91,12 @@
"# run the models\n",
"for model_dict in models:\n",
" name_short = model_dict['short']\n",
" sd = SD(name_short)\n",
" sd.pipeline.set_progress_bar_config(disable=True)\n",
" sd = SD(name_short, disable_progress_bar=True)\n",
" extract_positions = model_dict['extract_positions'] or sd.available_extract_positions\n",
" del model_dict['extract_positions'] # not needed anymore\n",
"\n",
" def get_reprs(prompt):\n",
" result = sd(prompt, seed=seed, extract_positions=model_dict['extract_positions'])\n",
" result = sd(prompt, seed=seed, extract_positions=sd.available_extract_positions)\n",
" representations = {}\n",
" for pos, reprs in result.representations.items():\n",
" representations[pos] = []\n",
Expand All @@ -119,6 +117,7 @@
" model_dict['representations'][pos] = {\n",
" 'channels': reprs[0].shape[-1],\n",
" 'spatial': reprs[0].shape[-2],\n",
" 'available': pos in extract_positions,\n",
" }\n",
"\n",
" # go through prompts\n",
Expand All @@ -132,8 +131,8 @@
" representations, images = get_reprs(prompt)\n",
"\n",
" # save representations\n",
" for pos, reprs in representations.items():\n",
" for j, repr in enumerate(reprs, 0):\n",
" for pos in extract_positions:\n",
" for j, repr in enumerate(representations[pos], 0):\n",
" with open(save_path / f'repr-{pos}-{j}.bin', 'wb') as f:\n",
" f.write(np.array(repr, dtype=np.float16).tobytes())\n",
"\n",
Expand Down
13 changes: 10 additions & 3 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>
// send a task to the webworker and return a promise
function callWasm(task, data) {
const id = worker.job_counter++;
const job = {id, task}
const job = {id, task};
const promise = new Promise((resolve, reject) => {
job.resolve = resolve;
job.reject = reject;
Expand Down Expand Up @@ -477,11 +477,18 @@ <h1 style="text-align: center;">H-Space similarity explorer</h1>


function updatePositionSelector() {
if (current_model.representations[current_position] === undefined) current_position = 'mid_block'; // fallback to mid_block, if current position is not available
if (current_model.representations[current_position] === undefined || !current_model.representations[current_position].available)
current_position = 'mid_block'; // fallback to mid_block, if current position is not available
document.getElementById('position-to-use').innerHTML = ''; // clear old options
Object.keys(current_model.representations).forEach(position => {
// setup available positions (where the representations are extracted from)
createElem('option', { value: position, textContent: position == 'mid_block' ? 'mid_block (h-space)' : position, selected: position == current_position }, document.getElementById('position-to-use'));
createElem('option', {
value: position,
textContent: position == 'mid_block' ? 'mid_block (h-space)' : position,
disabled: !current_model.representations[position].available,
title: !current_model.representations[position].available ? 'Not available due to large representation size' : undefined,
selected: position == current_position
}, document.getElementById('position-to-use'));
});
}

Expand Down
23 changes: 15 additions & 8 deletions sdwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from typing import Optional, Callable, Any, Literal
from PIL.Image import Image as PILImage
from dataclasses import dataclass



Expand Down Expand Up @@ -38,18 +39,16 @@ def load_sdxl_lightning(steps: int, device: str):
return pipe



@dataclass
class SDResult:
prompt : str
seed : int
representations : dict[str, list[torch.Tensor]]
representations : dict[str, list[torch.Tensor|tuple]]
images : list[PILImage]
result_latent : torch.Tensor
result_tensor : torch.Tensor
result_image : PILImage
def __init__(self, **kwargs): self.__dict__.update(kwargs)
def __repr__(self): return f'<SDResult prompt="{self.prompt}" seed={self.seed} ...>'

def __repr__(self): return f'SDResult(prompt="{self.prompt}",seed={self.seed},...)'

class SD:
'''
Expand Down Expand Up @@ -100,9 +99,10 @@ def __init__(
self,
model_name: Literal['SD1.5', 'SD2.1', 'SD-Turbo', 'SDXL-Turbo'] | str = 'SD1.5',
device: str = 'auto',
disable_progress_bar: bool = False,
):
self.model_name = model_name
self.config = self.known_models.get(model_name, {'name':model_name})
self.config = self.known_models.get(model_name, {'name': model_name})
self.device = device if device != 'auto' else 'cuda' if torch.cuda.is_available() else 'cpu'

# setup pipeline
Expand All @@ -112,12 +112,19 @@ def __init__(
self.pipeline = AutoPipelineForText2Image.from_pretrained(self.config['name'], torch_dtype=torch.float16).to(self.device)
self.vae = self.pipeline.vae

if disable_progress_bar and hasattr(self.pipeline, 'set_progress_bar_config'):
self.pipeline.set_progress_bar_config(disable=True)

# upcast vae if necessary (SDXL models require float32)
if hasattr(self.pipeline, 'upcast_vae') and self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
self.pipeline.upcast_vae()

# check h-space dim
# TODO
# setup available extract positions
self.available_extract_positions = []
for name, obj in self.pipeline.unet.named_children():
if any(x in name for x in ['time_', '_norm', '_act', '_embedding']): continue
self.available_extract_positions += [f'{name}[{i}]' for i in range(len(obj))] if isinstance(obj, torch.nn.ModuleList) else [name]
self.available_extract_positions.sort(key=lambda x: ['conv_in', 'down_blocks', 'mid_block', 'up_blocks', 'conv_out'].index(x.split('[')[0]))

@torch.no_grad()
def vae_decode(self, latents):
Expand Down

0 comments on commit e11293a

Please sign in to comment.