From e11293ac40798b91bde61b7055e619674df20d17 Mon Sep 17 00:00:00 2001 From: Jonas Loos <33965649+JonasLoos@users.noreply.github.com> Date: Wed, 20 Mar 2024 02:39:36 +0100 Subject: [PATCH] show disabled representaion positions --- generate-reprs.ipynb | 21 ++++++++++----------- index.html | 13 ++++++++++--- sdwrapper.py | 23 +++++++++++++++-------- 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/generate-reprs.ipynb b/generate-reprs.ipynb index 6b14cb3..4806f11 100644 --- a/generate-reprs.ipynb +++ b/generate-reprs.ipynb @@ -37,7 +37,7 @@ "models: list[dict[str,Any]] = [\n", " dict(\n", " short='SDXL-Turbo',\n", - " extract_positions = SD.available_extract_positions,\n", + " extract_positions = None, # use all\n", " ),\n", " dict(\n", " short='SDXL-Lightning',\n", @@ -45,26 +45,23 @@ " ),\n", " dict(\n", " short='SD-Turbo',\n", - " extract_positions = SD.available_extract_positions,\n", + " extract_positions = None, # use all\n", " ),\n", " dict(\n", " short='SD-1.5',\n", - " name='runwayml/stable-diffusion-v1-5',\n", - " extract_positions = ['down_blocks[2]', 'mid_block', 'up_blocks[0]', 'conv_out'],\n", + " extract_positions = ['down_blocks[2]', 'down_blocks[3]', 'mid_block', 'up_blocks[0]', 'conv_out'],\n", " ),\n", "]\n", "\n", "prompts = {\n", " \"Cat\": \"A photo of a cat.\",\n", " \"Dog\": \"A photograph of a husky, dog, looking friendly and cute.\",\n", - " # \"Polarbear\": \"A photo of a polar bear.\",\n", " \"Woman\": \"A photo of a beautiful, slightly smiling woman in the city.\",\n", " \"OldMan\": \"A portrait of an old man with a long beard and a hat.\",\n", " \"ConstructionWorker\": \"A photo of a hard working construction worker.\",\n", " \"FuturisticCityscape\": \"A futuristic cityscape at sunset, with flying cars and towering skyscrapers, in the style of cyberpunk.\",\n", " \"MountainLandscape\": \"A serene mountain landscape with a crystal-clear lake in the foreground, reflecting the snow-capped peaks under a bright blue sky.\",\n", " \"SpaceAstronaut\": \"A high-res photo of an astronaut floating in the vastness of space, with a colorful nebula and distant galaxies in the background.\",\n", - " # \"MajesticLion\": \"A close-up portrait of a majestic lion, with detailed fur and piercing eyes, set against the backdrop of the African savannah at dusk.\",\n", " \"MagicalForest\": \"A magical forest filled with glowing plants, mythical creatures, and a pathway leading to an enchanted castle.\",\n", " \"JapaneseGarden\": \"A traditional Japanese garden in spring, complete with cherry blossoms, a koi pond, and a wooden bridge.\",\n", "}\n", @@ -94,11 +91,12 @@ "# run the models\n", "for model_dict in models:\n", " name_short = model_dict['short']\n", - " sd = SD(name_short)\n", - " sd.pipeline.set_progress_bar_config(disable=True)\n", + " sd = SD(name_short, disable_progress_bar=True)\n", + " extract_positions = model_dict['extract_positions'] or sd.available_extract_positions\n", + " del model_dict['extract_positions'] # not needed anymore\n", "\n", " def get_reprs(prompt):\n", - " result = sd(prompt, seed=seed, extract_positions=model_dict['extract_positions'])\n", + " result = sd(prompt, seed=seed, extract_positions=sd.available_extract_positions)\n", " representations = {}\n", " for pos, reprs in result.representations.items():\n", " representations[pos] = []\n", @@ -119,6 +117,7 @@ " model_dict['representations'][pos] = {\n", " 'channels': reprs[0].shape[-1],\n", " 'spatial': reprs[0].shape[-2],\n", + " 'available': pos in extract_positions,\n", " }\n", "\n", " # go through prompts\n", @@ -132,8 +131,8 @@ " representations, images = get_reprs(prompt)\n", "\n", " # save representations\n", - " for pos, reprs in representations.items():\n", - " for j, repr in enumerate(reprs, 0):\n", + " for pos in extract_positions:\n", + " for j, repr in enumerate(representations[pos], 0):\n", " with open(save_path / f'repr-{pos}-{j}.bin', 'wb') as f:\n", " f.write(np.array(repr, dtype=np.float16).tobytes())\n", "\n", diff --git a/index.html b/index.html index 8260f1e..5f42db3 100644 --- a/index.html +++ b/index.html @@ -154,7 +154,7 @@