diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..729f1cc --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.vscode +**__pycache__ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..559c4bf --- /dev/null +++ b/README.md @@ -0,0 +1,176 @@ +# Multimodal Garment Designer +### Human-Centric Latent Diffusion Models for Fashion Image Editing +[**Alberto Baldrati**](https://scholar.google.com/citations?hl=en&user=I1jaZecAAAAJ)**\***, +[**Davide Morelli**](https://scholar.google.com/citations?user=UJ4D3rYAAAAJ&hl=en)**\***, +[**Giuseppe Cartella**](https://scholar.google.com/citations?hl=en&user=0sJ4VCcAAAAJ), +[**Marcella Cornia**](https://scholar.google.com/citations?hl=en&user=DzgmSJEAAAAJ), +[**Marco Bertini**](https://scholar.google.com/citations?user=SBm9ZpYAAAAJ&hl=en), +[**Rita Cucchiara**](https://scholar.google.com/citations?hl=en&user=OM3sZEoAAAAJ) + +**\*** Equal contribution. + +[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2304.02051) +[![GitHub Stars](https://img.shields.io/github/stars/aimagelab/multimodal-garment-designer?style=social)](https://github.com/aimagelab/multimodal-garment-designer) + +This is the **official repository** for the [**paper**](https://arxiv.org/abs/2304.02051) "*Multimodal Garment Designer: Human-Centric Latent Diffusion Models for Fashion Image Editing*". + +## Overview + +

+ +

+ +>**Abstract**:
+> Fashion illustration is used by designers to communicate their vision and to bring the design idea from conceptualization to realization, showing how clothes interact with the human body. In this context, computer vision can thus be used to improve the fashion design process. Differently from previous works that mainly focused on the virtual try-on of garments, we propose the task of multimodal-conditioned fashion image editing, guiding the generation of human-centric fashion images by following multimodal prompts, such as text, human body poses, and garment sketches. We tackle this problem by proposing a new architecture based on latent diffusion models, an approach that has not been used before in the fashion domain. Given the lack of existing datasets suitable for the task, we also extend two existing fashion datasets, namely Dress Code and VITON-HD, with multimodal annotations collected in a semi-automatic manner. Experimental results on these new datasets demonstrate the effectiveness of our proposal, both in terms of realism and coherence with the given multimodal inputs.
+
+## Citation
+If you make use of our work, please cite our paper:
+
+```bibtex
+@article{baldrati2023multimodal,
+  title={Multimodal Garment Designer: Human-Centric Latent Diffusion Models for Fashion Image Editing},
+  author={Baldrati, Alberto and Morelli, Davide and Cartella, Giuseppe and Cornia, Marcella and Bertini, Marco and Cucchiara, Rita},
+  journal={arXiv preprint arXiv:2304.02051},
+  year={2023}
+}
+```
+
+## Inference
+
+To run the inference, use the following command:
+
+```
+python eval.py --dataset_path <path> --batch_size <int> --mixed_precision fp16 --output_dir <path> --save_name <string> --num_workers_test <int> --sketch_cond_rate 0.2 --dataset <dataset> --start_cond_rate 0.0 --test_order <paired/unpaired>
+```
+
+- ```dataset_path``` path to the dataset (change it according to the ```dataset``` parameter)
+- ```dataset``` name of the dataset to be used (dresscode | vitonhd)
+- ```output_dir``` path to the output directory
+- ```save_name``` name of the ```output_dir``` subfolder where the generated images are saved
+- ```start_cond_rate``` rate {0.0,1.0} of denoising steps used as an offset before starting the sketch conditioning
+- ```sketch_cond_rate``` rate {0.0,1.0} of denoising steps in which the sketch conditioning is applied
+- ```test_order``` test setting (paired | unpaired)
+
+Note that we provide a few sample images to test MGD by simply cloning this repo (*i.e.*, assets/data). To execute the code, set:
+- Dress Code Multimodal dataset
+  - ```dataset_path``` to ```assets/data/dresscode```
+  - ```dataset``` to ```dresscode```
+- Viton-HD Multimodal dataset
+  - ```dataset_path``` to ```assets/data/vitonhd```
+  - ```dataset``` to ```vitonhd```
+
+It is possible to run the inference on the whole Dress Code Multimodal or Viton-HD Multimodal dataset by simply changing ```dataset_path``` and ```dataset``` according to the downloaded and prepared datasets (see sections below).
+
+
+## Pre-trained models
+The model and checkpoints are available via torch.hub.
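+
+To check which entry points the hub repository exposes before downloading any weights, you can list them first. This is an optional sketch (it only assumes network access to GitHub); the ```mgd``` entry point used below should appear in the returned list.
+
+```python
+import torch
+
+# List the entry points defined in the hubconf of this repository.
+print(torch.hub.list('aimagelab/multimodal-garment-designer'))
+```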
+
+Load the MGD denoising UNet model using the following code:
+
+```
+unet = torch.hub.load(
+    dataset=<dataset>,
+    repo_or_dir='aimagelab/multimodal-garment-designer',
+    source='github',
+    model='mgd',
+    pretrained=True
+    )
+```
+
+- ```dataset``` dataset name (dresscode | vitonhd)
+
+Use the denoising network with our custom diffusers pipeline as follows:
+
+```
+from pipes.sketch_posemap_inpaint_pipe import StableDiffusionSketchPosemapInpaintPipeline as ValPipe
+from diffusers import AutoencoderKL, DDIMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer
+
+pretrained_model_name_or_path = "runwayml/stable-diffusion-inpainting"
+
+text_encoder = CLIPTextModel.from_pretrained(
+    pretrained_model_name_or_path,
+    subfolder="text_encoder"
+    )
+
+vae = AutoencoderKL.from_pretrained(
+    pretrained_model_name_or_path,
+    subfolder="vae"
+    )
+
+tokenizer = CLIPTokenizer.from_pretrained(
+    pretrained_model_name_or_path,
+    subfolder="tokenizer",
+    )
+
+val_scheduler = DDIMScheduler.from_pretrained(
+    pretrained_model_name_or_path,
+    subfolder="scheduler"
+    )
+val_scheduler.set_timesteps(50)
+
+val_pipe = ValPipe(
+    text_encoder=text_encoder,
+    vae=vae,
+    unet=unet,
+    tokenizer=tokenizer,
+    scheduler=val_scheduler,
+    )
+```
+
+For an extensive usage example, see ```eval.py``` in the main repo.
+
+## Datasets
+You can download the Dress Code Multimodal and Viton-HD Multimodal additional data annotations from the following links.
+
+- Dress Code Multimodal **[[link](https://drive.google.com/file/d/1GABxne7cEHyFgmVoffgssfYKvy91KLos/view?usp=share_link)]**
+- Viton-HD Multimodal **[[link](https://drive.google.com/file/d/1Z2b9YkyBPA_9ZDC54Y5muW9Q8yfAqWSH/view?usp=share_link)]**
+
+### Dress Code Multimodal Data Preparation
+Once the data is downloaded, prepare the dataset folder as follows (a loading sketch is shown after the tree):
+
+
+Dress Code
+| fine_captions.json
+| coarse_captions.json
+| test_pairs_paired.txt
+| test_pairs_unpaired.txt
+| train_pairs.txt
+|---- test_stitchmap
+|---- [category]
+|-------- images
+-------- keypoints
+-------- label_maps
+|-------- skeletons
+|-------- dense
+|-------- im_sketch
+|-------- im_sketch_unpaired
+...
+
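+
+To sanity-check the prepared folder, you can instantiate the dataset class shipped in ```datasets/dresscode.py``` on it. The snippet below is only a minimal sketch: the tokenizer checkpoint mirrors the pipeline example above, while the ```order``` and ```size``` values are illustrative assumptions rather than prescribed settings.
+
+```python
+from transformers import CLIPTokenizer
+from datasets.dresscode import Dataset as DressCodeDataset
+
+# Tokenizer used by the dataset to encode the garment captions.
+tokenizer = CLIPTokenizer.from_pretrained(
+    "runwayml/stable-diffusion-inpainting", subfolder="tokenizer")
+
+# Point the dataset at the folder prepared above (here: the bundled samples).
+dresscode_test = DressCodeDataset(
+    dataroot_path="assets/data/dresscode",
+    phase="test",
+    order="paired",  # or "unpaired"
+    tokenizer=tokenizer,
+    category=("dresses", "upper_body", "lower_body"),
+    size=(512, 384),  # (height, width); assumed here, the class default is (256, 192)
+)
+
+sample = dresscode_test[0]
+print(sample["im_name"], sample["original_captions"])
+```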
+
+### Viton-HD Multimodal Data Preparation
+Once the data is downloaded, prepare the dataset folder as follows (a loading sketch is shown after the tree):
+
+
+Viton-HD
+| captions.json
+| train_pairs.txt
+| test_pairs.txt
+|---- train
+|-------- image
+|-------- cloth
+|-------- image-parse-v3
+|-------- openpose_json
+|-------- im_sketch
+|-------- im_sketch_unpaired
+...
+|---- test
+...
+|-------- im_sketch
+|-------- im_sketch_unpaired
+...
+
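+
+As with Dress Code, the prepared folder can be loaded with the dataset class in ```datasets/vitonhd.py```. Again, this is only a sketch under the same assumptions (the tokenizer checkpoint and ```size``` value are illustrative, not prescribed).
+
+```python
+from transformers import CLIPTokenizer
+from datasets.vitonhd import Dataset as VitonHDDataset
+
+tokenizer = CLIPTokenizer.from_pretrained(
+    "runwayml/stable-diffusion-inpainting", subfolder="tokenizer")
+
+# phase="test" reads test_pairs.txt; order selects paired or unpaired sketches.
+vitonhd_test = VitonHDDataset(
+    dataroot_path="assets/data/vitonhd",
+    phase="test",
+    order="paired",
+    tokenizer=tokenizer,
+    size=(512, 384),  # (height, width); assumed here, the class default is (256, 192)
+)
+
+sample = vitonhd_test[0]
+print(sample["im_name"], sample["original_captions"])
+```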
+ + +## TODO +- [ ] training code + +## Acknowledgements +This work has partially been supported by the PNRR project “Future Artificial Intelligence Research (FAIR)”, by the PRIN project “CREATIVE: CRoss-modal understanding and gEnerATIon of Visual and tExtual content” (CUP B87G22000460001), both co-funded by the Italian Ministry of University and Research, and by the European Commission under European Horizon 2020 Programme, grant number 101004545 - ReInHerit. \ No newline at end of file diff --git a/assets/data/dresscode/dresses/im_sketch/052012_1.png b/assets/data/dresscode/dresses/im_sketch/052012_1.png new file mode 100644 index 0000000..b4bbae5 Binary files /dev/null and b/assets/data/dresscode/dresses/im_sketch/052012_1.png differ diff --git a/assets/data/dresscode/dresses/im_sketch_unpaired/052012_0_052033_1.png b/assets/data/dresscode/dresses/im_sketch_unpaired/052012_0_052033_1.png new file mode 100644 index 0000000..1cd1864 Binary files /dev/null and b/assets/data/dresscode/dresses/im_sketch_unpaired/052012_0_052033_1.png differ diff --git a/assets/data/dresscode/dresses/images/052012_0.jpg b/assets/data/dresscode/dresses/images/052012_0.jpg new file mode 100755 index 0000000..928d926 Binary files /dev/null and b/assets/data/dresscode/dresses/images/052012_0.jpg differ diff --git a/assets/data/dresscode/dresses/keypoints/051994_2.json b/assets/data/dresscode/dresses/keypoints/051994_2.json new file mode 100755 index 0000000..97f9496 --- /dev/null +++ b/assets/data/dresscode/dresses/keypoints/051994_2.json @@ -0,0 +1 @@ +{"keypoints": [[205.0, 0.0, 0.7454180121421814, 0.0], [193.0, 54.0, 0.9695595502853394, 1.0], [153.0, 46.0, 0.9210028052330017, 2.0], [163.0, 133.0, 0.8196305632591248, 3.0], [182.0, 193.0, 0.8251658082008362, 4.0], [234.0, 59.0, 0.9184455871582031, 5.0], [252.0, 133.0, 0.9567239880561829, 6.0], [186.0, 139.0, 0.8934434652328491, 7.0], [180.0, 199.0, 0.803178071975708, 8.0], [164.0, 323.0, 0.9721584916114807, 9.0], [140.0, 438.0, 0.8884657025337219, 10.0], [230.0, 199.0, 0.7910050749778748, 11.0], [223.0, 322.0, 0.8632205128669739, 12.0], [211.0, 436.0, 0.8701340556144714, 13.0], [197.0, 0.0, 0.5216715931892395, 14.0], [214.0, 0.0, 0.610032856464386, 15.0], [186.0, 0.0, 0.5596458911895752, 16.0], [224.0, 0.0, 0.6544457077980042, 17.0]]} \ No newline at end of file diff --git a/assets/data/dresscode/dresses/keypoints/052012_2.json b/assets/data/dresscode/dresses/keypoints/052012_2.json new file mode 100755 index 0000000..54ce2e4 --- /dev/null +++ b/assets/data/dresscode/dresses/keypoints/052012_2.json @@ -0,0 +1 @@ +{"keypoints": [[209.0, 0.0, 0.5854654312133789, 0.0], [194.0, 46.0, 0.9134835004806519, 1.0], [152.0, 37.0, 0.9346101880073547, 2.0], [137.0, 118.0, 0.8949743509292603, 3.0], [142.0, 188.0, 0.9262046813964844, 4.0], [235.0, 53.0, 0.9443597793579102, 5.0], [242.0, 134.0, 0.8160827159881592, 6.0], [249.0, 208.0, 0.7443605065345764, 7.0], [167.0, 194.0, 0.8199771046638489, 8.0], [160.0, 319.0, 0.907497763633728, 9.0], [143.0, 447.0, 0.8496435284614563, 10.0], [224.0, 195.0, 0.7766034007072449, 11.0], [214.0, 307.0, 0.8796323537826538, 12.0], [195.0, 430.0, 0.9041213393211365, 13.0], [201.0, 0.0, 0.36648842692375183, 14.0], [217.0, 0.0, 0.44042539596557617, 15.0], [187.0, 0.0, 0.40841901302337646, 16.0], [227.0, 0.0, 0.2826925814151764, 17.0]]} \ No newline at end of file diff --git a/assets/data/dresscode/dresses/label_maps/051994_4.png b/assets/data/dresscode/dresses/label_maps/051994_4.png new file mode 100755 index 0000000..f19e54f Binary files 
/dev/null and b/assets/data/dresscode/dresses/label_maps/051994_4.png differ diff --git a/assets/data/dresscode/dresses/label_maps/052012_4.png b/assets/data/dresscode/dresses/label_maps/052012_4.png new file mode 100755 index 0000000..1fb6231 Binary files /dev/null and b/assets/data/dresscode/dresses/label_maps/052012_4.png differ diff --git a/assets/data/dresscode/dresses/test_pairs_paired.txt b/assets/data/dresscode/dresses/test_pairs_paired.txt new file mode 100644 index 0000000..3e4702a --- /dev/null +++ b/assets/data/dresscode/dresses/test_pairs_paired.txt @@ -0,0 +1 @@ +052012_0.jpg 052012_1.jpg \ No newline at end of file diff --git a/assets/data/dresscode/dresses/test_pairs_unpaired.txt b/assets/data/dresscode/dresses/test_pairs_unpaired.txt new file mode 100644 index 0000000..a4c7ad5 --- /dev/null +++ b/assets/data/dresscode/dresses/test_pairs_unpaired.txt @@ -0,0 +1 @@ +052012_0.jpg 052033_1.jpg \ No newline at end of file diff --git a/assets/data/dresscode/fine_captions.json b/assets/data/dresscode/fine_captions.json new file mode 100644 index 0000000..b4ad0de --- /dev/null +++ b/assets/data/dresscode/fine_captions.json @@ -0,0 +1,2 @@ +{"052033": ["blue belted dress", "light teal blue", "long loose blue dress"], "052896": ["cream dress", "natural sleeveless v-neck dress", "sleevless beige dress"], "049951": ["flowy top", "long sleeved and colored light blue", "sheer blue blouse"], "049534": ["black sweatshirt", "grey crew neck", "long sleeved and dark grey"], "050908": ["skinny denim", "skinny jeans", "skinny mid-rise jeans"], "051078": ["beige shorts", "mid rise shorts", "relaxed fit shorts"], "052012": ["black bell-sleeve mini dress", "black bow detail jersey dress", "long sleeve short black dress"], "051994": ["cream dress", "short cream colored dress", "white short sleeved dress"], "048462": ["black mockneck top", "long sleeved and black color", "long sleeved black top"], "048466": ["short sleeved and picture graphic", "yellow abstract print t-shirt", "yellow graphic tee"], "050855": ["black lexi midi jersey tube skirt", "black maddie jersey midi length tube skirt", "black ttya midi tube skirt"], "050915": ["blue original jeans", "blue straight leg", "blue tapered-leg denim trousers"] +} diff --git a/assets/data/dresscode/lower_body/im_sketch/050855_1.png b/assets/data/dresscode/lower_body/im_sketch/050855_1.png new file mode 100644 index 0000000..18d5b8a Binary files /dev/null and b/assets/data/dresscode/lower_body/im_sketch/050855_1.png differ diff --git a/assets/data/dresscode/lower_body/im_sketch/050915_1.png b/assets/data/dresscode/lower_body/im_sketch/050915_1.png new file mode 100644 index 0000000..7305dda Binary files /dev/null and b/assets/data/dresscode/lower_body/im_sketch/050915_1.png differ diff --git a/assets/data/dresscode/lower_body/im_sketch_unpaired/050855_0_050908_1.png b/assets/data/dresscode/lower_body/im_sketch_unpaired/050855_0_050908_1.png new file mode 100644 index 0000000..5b4cd3d Binary files /dev/null and b/assets/data/dresscode/lower_body/im_sketch_unpaired/050855_0_050908_1.png differ diff --git a/assets/data/dresscode/lower_body/im_sketch_unpaired/050915_0_051078_1.png b/assets/data/dresscode/lower_body/im_sketch_unpaired/050915_0_051078_1.png new file mode 100644 index 0000000..39098cd Binary files /dev/null and b/assets/data/dresscode/lower_body/im_sketch_unpaired/050915_0_051078_1.png differ diff --git a/assets/data/dresscode/lower_body/images/050855_0.jpg b/assets/data/dresscode/lower_body/images/050855_0.jpg new file mode 100755 index 
0000000..39c5aef Binary files /dev/null and b/assets/data/dresscode/lower_body/images/050855_0.jpg differ diff --git a/assets/data/dresscode/lower_body/images/050915_0.jpg b/assets/data/dresscode/lower_body/images/050915_0.jpg new file mode 100755 index 0000000..fbab5ae Binary files /dev/null and b/assets/data/dresscode/lower_body/images/050915_0.jpg differ diff --git a/assets/data/dresscode/lower_body/keypoints/050855_2.json b/assets/data/dresscode/lower_body/keypoints/050855_2.json new file mode 100755 index 0000000..0b35f4c --- /dev/null +++ b/assets/data/dresscode/lower_body/keypoints/050855_2.json @@ -0,0 +1 @@ +{"keypoints": [[183.0, 0.0, 0.5823147296905518, 0.0], [208.0, 53.0, 0.9465193152427673, 1.0], [166.0, 48.0, 0.9355934262275696, 2.0], [149.0, 138.0, 0.8791153430938721, 3.0], [142.0, 213.0, 0.9196943640708923, 4.0], [249.0, 54.0, 0.9031239151954651, 5.0], [257.0, 143.0, 0.8434740900993347, 6.0], [269.0, 214.0, 0.9049280881881714, 7.0], [181.0, 196.0, 0.7393415570259094, 8.0], [162.0, 321.0, 0.8804817795753479, 9.0], [125.0, 435.0, 0.8968222141265869, 10.0], [235.0, 199.0, 0.7837328314781189, 11.0], [233.0, 322.0, 0.927562415599823, 12.0], [217.0, 439.0, 0.8323052525520325, 13.0], [176.0, 0.0, 0.2889324426651001, 14.0], [190.0, 0.0, 0.395557701587677, 15.0], [-1.0, -1.0, 0.0, -1.0], [213.0, 0.0, 0.4816431403160095, 16.0]]} \ No newline at end of file diff --git a/assets/data/dresscode/lower_body/keypoints/050915_2.json b/assets/data/dresscode/lower_body/keypoints/050915_2.json new file mode 100755 index 0000000..48a49f3 --- /dev/null +++ b/assets/data/dresscode/lower_body/keypoints/050915_2.json @@ -0,0 +1 @@ +{"keypoints": [[204.0, 0.0, 0.7293938398361206, 0.0], [199.0, 62.0, 0.9325352907180786, 1.0], [152.0, 67.0, 0.9000401496887207, 2.0], [144.0, 146.0, 0.9149616956710815, 3.0], [132.0, 213.0, 0.8797067999839783, 4.0], [246.0, 59.0, 0.8685039281845093, 5.0], [279.0, 135.0, 0.9230703115463257, 6.0], [228.0, 163.0, 0.908068835735321, 7.0], [160.0, 218.0, 0.8140541315078735, 8.0], [160.0, 327.0, 0.9113076329231262, 9.0], [162.0, 432.0, 0.8555663228034973, 10.0], [220.0, 218.0, 0.7834039926528931, 11.0], [227.0, 327.0, 0.8965063691139221, 12.0], [225.0, 434.0, 0.876929759979248, 13.0], [190.0, 0.0, 0.510018527507782, 14.0], [211.0, 0.0, 0.48887312412261963, 15.0], [172.0, 0.0, 0.6903271079063416, 16.0], [219.0, 0.0, 0.19232189655303955, 17.0]]} \ No newline at end of file diff --git a/assets/data/dresscode/lower_body/label_maps/050855_4.png b/assets/data/dresscode/lower_body/label_maps/050855_4.png new file mode 100755 index 0000000..e6cb0eb Binary files /dev/null and b/assets/data/dresscode/lower_body/label_maps/050855_4.png differ diff --git a/assets/data/dresscode/lower_body/label_maps/050915_4.png b/assets/data/dresscode/lower_body/label_maps/050915_4.png new file mode 100755 index 0000000..319055d Binary files /dev/null and b/assets/data/dresscode/lower_body/label_maps/050915_4.png differ diff --git a/assets/data/dresscode/lower_body/test_pairs_paired.txt b/assets/data/dresscode/lower_body/test_pairs_paired.txt new file mode 100644 index 0000000..f482284 --- /dev/null +++ b/assets/data/dresscode/lower_body/test_pairs_paired.txt @@ -0,0 +1,2 @@ +050855_0.jpg 050855_1.jpg +050915_0.jpg 050915_1.jpg diff --git a/assets/data/dresscode/lower_body/test_pairs_unpaired.txt b/assets/data/dresscode/lower_body/test_pairs_unpaired.txt new file mode 100644 index 0000000..64c2132 --- /dev/null +++ b/assets/data/dresscode/lower_body/test_pairs_unpaired.txt @@ -0,0 +1,2 @@ +050855_0.jpg 
050908_1.jpg +050915_0.jpg 051078_1.jpg diff --git a/assets/data/dresscode/test_stitchmap/048462_0.png b/assets/data/dresscode/test_stitchmap/048462_0.png new file mode 100644 index 0000000..73b8f3e Binary files /dev/null and b/assets/data/dresscode/test_stitchmap/048462_0.png differ diff --git a/assets/data/dresscode/test_stitchmap/048466_0.png b/assets/data/dresscode/test_stitchmap/048466_0.png new file mode 100644 index 0000000..e170529 Binary files /dev/null and b/assets/data/dresscode/test_stitchmap/048466_0.png differ diff --git a/assets/data/dresscode/test_stitchmap/050855_0.png b/assets/data/dresscode/test_stitchmap/050855_0.png new file mode 100644 index 0000000..8bb37ba Binary files /dev/null and b/assets/data/dresscode/test_stitchmap/050855_0.png differ diff --git a/assets/data/dresscode/test_stitchmap/050915_0.png b/assets/data/dresscode/test_stitchmap/050915_0.png new file mode 100644 index 0000000..06111a2 Binary files /dev/null and b/assets/data/dresscode/test_stitchmap/050915_0.png differ diff --git a/assets/data/dresscode/test_stitchmap/051994_0.png b/assets/data/dresscode/test_stitchmap/051994_0.png new file mode 100644 index 0000000..90f362c Binary files /dev/null and b/assets/data/dresscode/test_stitchmap/051994_0.png differ diff --git a/assets/data/dresscode/test_stitchmap/052012_0.png b/assets/data/dresscode/test_stitchmap/052012_0.png new file mode 100644 index 0000000..fb15a72 Binary files /dev/null and b/assets/data/dresscode/test_stitchmap/052012_0.png differ diff --git a/assets/data/dresscode/upper_body/im_sketch/048462_1.png b/assets/data/dresscode/upper_body/im_sketch/048462_1.png new file mode 100644 index 0000000..4a15223 Binary files /dev/null and b/assets/data/dresscode/upper_body/im_sketch/048462_1.png differ diff --git a/assets/data/dresscode/upper_body/im_sketch_unpaired/048462_0_049951_1.png b/assets/data/dresscode/upper_body/im_sketch_unpaired/048462_0_049951_1.png new file mode 100644 index 0000000..6df5621 Binary files /dev/null and b/assets/data/dresscode/upper_body/im_sketch_unpaired/048462_0_049951_1.png differ diff --git a/assets/data/dresscode/upper_body/im_sketch_unpaired/048466_0_049534_1.png b/assets/data/dresscode/upper_body/im_sketch_unpaired/048466_0_049534_1.png new file mode 100644 index 0000000..8a19a59 Binary files /dev/null and b/assets/data/dresscode/upper_body/im_sketch_unpaired/048466_0_049534_1.png differ diff --git a/assets/data/dresscode/upper_body/images/048462_0.jpg b/assets/data/dresscode/upper_body/images/048462_0.jpg new file mode 100755 index 0000000..efd257c Binary files /dev/null and b/assets/data/dresscode/upper_body/images/048462_0.jpg differ diff --git a/assets/data/dresscode/upper_body/keypoints/048462_2.json b/assets/data/dresscode/upper_body/keypoints/048462_2.json new file mode 100755 index 0000000..1b86c84 --- /dev/null +++ b/assets/data/dresscode/upper_body/keypoints/048462_2.json @@ -0,0 +1 @@ +{"keypoints": [[141.0, 0.0, 0.5669732689857483, 0.0], [138.0, 52.0, 0.9372020363807678, 1.0], [97.0, 51.0, 0.9566869735717773, 2.0], [91.0, 127.0, 0.9176412224769592, 3.0], [92.0, 196.0, 0.9124826192855835, 4.0], [175.0, 52.0, 0.9164174795150757, 5.0], [190.0, 126.0, 0.9348884224891663, 6.0], [200.0, 192.0, 0.9007474780082703, 7.0], [121.0, 186.0, 0.8528367280960083, 8.0], [131.0, 298.0, 0.9579567909240723, 9.0], [147.0, 408.0, 0.9156394600868225, 10.0], [171.0, 184.0, 0.8514549732208252, 11.0], [188.0, 299.0, 0.98090660572052, 12.0], [207.0, 421.0, 0.8863993287086487, 13.0], [133.0, 0.0, 0.38882675766944885, 14.0], 
[148.0, 0.0, 0.43583664298057556, 15.0], [120.0, 0.0, 0.44555413722991943, 16.0], [159.0, 0.0, 0.3003842532634735, 17.0]]} \ No newline at end of file diff --git a/assets/data/dresscode/upper_body/keypoints/048466_2.json b/assets/data/dresscode/upper_body/keypoints/048466_2.json new file mode 100755 index 0000000..06ffb6e --- /dev/null +++ b/assets/data/dresscode/upper_body/keypoints/048466_2.json @@ -0,0 +1 @@ +{"keypoints": [[200.0, 0.0, 0.688823401927948, 0.0], [201.0, 49.0, 0.9461612105369568, 1.0], [162.0, 46.0, 0.9666376709938049, 2.0], [151.0, 113.0, 0.938896894454956, 3.0], [146.0, 177.0, 0.9722222685813904, 4.0], [238.0, 52.0, 0.9425340890884399, 5.0], [247.0, 117.0, 0.9355710744857788, 6.0], [246.0, 180.0, 0.9147509932518005, 7.0], [170.0, 173.0, 0.8584168553352356, 8.0], [159.0, 281.0, 0.938974916934967, 9.0], [156.0, 372.0, 0.9198458790779114, 10.0], [221.0, 173.0, 0.8431934714317322, 11.0], [219.0, 280.0, 0.9141676425933838, 12.0], [214.0, 367.0, 0.944513201713562, 13.0], [192.0, 0.0, 0.4936670660972595, 14.0], [208.0, 0.0, 0.5329393148422241, 15.0], [182.0, 0.0, 0.4606328308582306, 16.0], [219.0, 0.0, 0.460784375667572, 17.0]]} \ No newline at end of file diff --git a/assets/data/dresscode/upper_body/label_maps/048462_4.png b/assets/data/dresscode/upper_body/label_maps/048462_4.png new file mode 100755 index 0000000..00a876f Binary files /dev/null and b/assets/data/dresscode/upper_body/label_maps/048462_4.png differ diff --git a/assets/data/dresscode/upper_body/label_maps/048466_4.png b/assets/data/dresscode/upper_body/label_maps/048466_4.png new file mode 100755 index 0000000..d5e44e3 Binary files /dev/null and b/assets/data/dresscode/upper_body/label_maps/048466_4.png differ diff --git a/assets/data/dresscode/upper_body/test_pairs_paired.txt b/assets/data/dresscode/upper_body/test_pairs_paired.txt new file mode 100644 index 0000000..21aa686 --- /dev/null +++ b/assets/data/dresscode/upper_body/test_pairs_paired.txt @@ -0,0 +1 @@ +048462_0.jpg 048462_1.jpg \ No newline at end of file diff --git a/assets/data/dresscode/upper_body/test_pairs_unpaired.txt b/assets/data/dresscode/upper_body/test_pairs_unpaired.txt new file mode 100644 index 0000000..fbbf340 --- /dev/null +++ b/assets/data/dresscode/upper_body/test_pairs_unpaired.txt @@ -0,0 +1 @@ +048462_0.jpg 049951_1.jpg \ No newline at end of file diff --git a/assets/data/vitonhd/captions.json b/assets/data/vitonhd/captions.json new file mode 100644 index 0000000..6246f0c --- /dev/null +++ b/assets/data/vitonhd/captions.json @@ -0,0 +1 @@ +{"12419": ["black curve cami rib lace", "black dentelle lace-detail camisole", "black bella cami"], "01944": ["yellow scoop tee", "yellow jersey tee", "gold t-shirt"], "03191": ["white petite t-shirt only macy", "white perforated leather front tee", "white detail tee"], "00349": ["vero moda black high neck blouse", "high-neck top", "black vanessa bruno ath\u00e9 high-neck ruffled blouse"]} diff --git a/assets/data/vitonhd/test/im_sketch/03191_00.png b/assets/data/vitonhd/test/im_sketch/03191_00.png new file mode 100644 index 0000000..bf735f2 Binary files /dev/null and b/assets/data/vitonhd/test/im_sketch/03191_00.png differ diff --git a/assets/data/vitonhd/test/im_sketch/12419_00.png b/assets/data/vitonhd/test/im_sketch/12419_00.png new file mode 100644 index 0000000..a9d84b7 Binary files /dev/null and b/assets/data/vitonhd/test/im_sketch/12419_00.png differ diff --git a/assets/data/vitonhd/test/im_sketch_unpaired/03191_00_00349_00.png 
b/assets/data/vitonhd/test/im_sketch_unpaired/03191_00_00349_00.png new file mode 100644 index 0000000..cc5c223 Binary files /dev/null and b/assets/data/vitonhd/test/im_sketch_unpaired/03191_00_00349_00.png differ diff --git a/assets/data/vitonhd/test/im_sketch_unpaired/12419_00_01944_00.png b/assets/data/vitonhd/test/im_sketch_unpaired/12419_00_01944_00.png new file mode 100644 index 0000000..b88a307 Binary files /dev/null and b/assets/data/vitonhd/test/im_sketch_unpaired/12419_00_01944_00.png differ diff --git a/assets/data/vitonhd/test/image-parse-v3/03191_00.png b/assets/data/vitonhd/test/image-parse-v3/03191_00.png new file mode 100644 index 0000000..aec1ee7 Binary files /dev/null and b/assets/data/vitonhd/test/image-parse-v3/03191_00.png differ diff --git a/assets/data/vitonhd/test/image-parse-v3/12419_00.png b/assets/data/vitonhd/test/image-parse-v3/12419_00.png new file mode 100644 index 0000000..ef6d023 Binary files /dev/null and b/assets/data/vitonhd/test/image-parse-v3/12419_00.png differ diff --git a/assets/data/vitonhd/test/image/03191_00.jpg b/assets/data/vitonhd/test/image/03191_00.jpg new file mode 100644 index 0000000..142785d Binary files /dev/null and b/assets/data/vitonhd/test/image/03191_00.jpg differ diff --git a/assets/data/vitonhd/test/image/12419_00.jpg b/assets/data/vitonhd/test/image/12419_00.jpg new file mode 100644 index 0000000..b8fdada Binary files /dev/null and b/assets/data/vitonhd/test/image/12419_00.jpg differ diff --git a/assets/data/vitonhd/test/openpose_json/03191_00_keypoints.json b/assets/data/vitonhd/test/openpose_json/03191_00_keypoints.json new file mode 100644 index 0000000..6a1810f --- /dev/null +++ b/assets/data/vitonhd/test/openpose_json/03191_00_keypoints.json @@ -0,0 +1 @@ +{"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[416.932,147.983,0.897289,419.886,323.564,0.823605,295.199,323.575,0.710266,247.078,553.188,0.845261,167.788,760.006,0.839517,547.427,317.9,0.678732,558.827,555.858,0.815472,572.928,779.771,0.816659,397.27,700.433,0.424303,317.926,697.611,0.385856,300.856,969.661,0.23729,0,0,0,485.062,703.303,0.382307,465.203,972.48,0.184967,0,0,0,383.025,122.292,0.896053,442.498,122.248,0.942278,351.819,145.093,0.790069,482.159,136.519,0.671479,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"face_keypoints_2d":[354.446,127.599,0.808671,356.595,142.643,0.809553,360.893,157.687,0.892167,363.042,174.163,0.84444,368.057,189.207,0.79902,378.086,202.818,0.82672,388.831,214.28,0.824691,401.01,224.309,0.842394,416.77,227.175,0.842865,431.097,223.593,0.78027,446.141,213.564,0.798111,457.603,201.386,0.764526,464.767,186.342,0.860417,469.781,169.865,0.836623,470.497,154.105,0.757279,474.796,138.345,0.686917,476.228,122.585,0.674373,361.609,111.123,0.890275,367.34,105.392,0.944614,378.086,103.959,0.885853,387.399,104.676,0.865673,395.995,106.825,0.928337,423.933,106.108,0.876463,433.963,101.094,0.847098,443.992,99.6609,0.891535,455.454,101.094,0.840891,464.05,108.974,0.840359,411.039,122.585,0.865094,411.755,134.047,0.900373,411.755,145.509,0.902167,412.472,156.971,0.886932,401.726,165.567,0.929022,407.457,167.716,0.912199,412.472,168.432,0.922348,418.202,167,0.88593,423.933,164.134,0.89659,372.355,123.301,0.937799,378.086,119.003,0.91301,388.115,119.719,0.852703,395.279,126.883,0.899933,387.399,128.316,0.889112,378.086,128.316,0.975661,428.948,124.018,0.915437,436.112,117.57,0.984511,446.141,116.854,0.92132,452.588,121.868,0.969951,446.857,126.167,0.888569,436.112,126.883,0.902612,393.13,186.342,0.886258,401.01,182.044,0.871408,408.173,179.894,0.9
16132,413.188,180.611,0.923757,418.919,179.894,0.908463,428.232,180.611,0.889455,436.112,184.909,0.896513,429.664,192.789,0.899923,421.784,197.087,0.877151,413.904,197.804,0.916279,407.457,197.804,0.952279,401.01,194.938,0.878204,395.995,186.342,0.910329,408.173,185.625,0.890515,413.904,185.625,0.906499,419.635,184.909,0.880725,433.246,185.625,0.876117,419.635,189.207,0.837431,413.904,190.64,0.878774,408.173,189.924,0.874297,383.817,122.585,0.863312,441.126,121.152,0.859384],"hand_left_keypoints_2d":[565.693,776.845,0.542976,552.89,801.536,0.678419,545.574,836.287,0.768888,550.146,871.038,0.833503,549.232,893.9,0.865609,580.325,852.748,0.609337,567.522,893.9,0.828608,551.975,914.019,0.795984,537.343,924.078,0.766036,583.983,855.491,0.627158,572.094,892.986,0.789977,552.89,910.361,0.538921,537.343,920.42,0.410428,585.812,850.919,0.656112,574.838,881.097,0.65846,556.548,901.216,0.531315,541.001,910.361,0.395474,581.239,846.346,0.689323,573.923,872.867,0.668061,560.206,880.183,0.523149,545.574,897.558,0.328769],"hand_right_keypoints_2d":[166.902,765.96,0.632211,166.902,790.336,0.735203,165.096,811.102,0.672094,153.36,845.409,0.939393,149.748,871.592,0.619071,132.594,830.061,0.808289,128.983,867.98,0.782519,138.011,891.454,0.756885,145.234,910.414,0.853907,134.4,835.478,0.617637,129.886,871.592,0.758741,143.428,895.968,0.768459,155.165,913.122,0.702863,139.817,837.284,0.584807,136.206,867.078,0.70512,150.651,888.746,0.801908,162.388,904.094,0.809094,149.748,833.673,0.443493,147.04,860.758,0.500639,154.262,875.203,0.701815,165.999,882.426,0.819218],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} \ No newline at end of file diff --git a/assets/data/vitonhd/test/openpose_json/12419_00_keypoints.json b/assets/data/vitonhd/test/openpose_json/12419_00_keypoints.json new file mode 100644 index 0000000..03bbcf8 --- /dev/null +++ b/assets/data/vitonhd/test/openpose_json/12419_00_keypoints.json @@ -0,0 +1 @@ 
+{"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[374.575,238.469,0.949289,337.721,374.564,0.8301,215.97,368.91,0.732457,173.452,604.043,0.810423,173.422,861.967,0.768282,456.755,380.306,0.76543,465.396,601.339,0.750355,527.56,844.897,0.734083,371.599,788.305,0.592811,283.936,793.987,0.534644,0,0,0,0,0,0,453.946,782.666,0.536147,0,0,0,0,0,0,351.821,201.756,0.89095,400.173,215.886,0.919544,303.749,212.944,0.836505,422.804,238.612,0.465832,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"face_keypoints_2d":[311.814,198.615,0.65869,308.136,214.552,0.718225,306.297,229.876,0.67781,305.684,247.039,0.660392,308.136,263.588,0.742804,314.266,278.299,0.792996,324.073,291.784,0.757003,337.558,302.817,0.862079,352.881,307.721,0.890189,366.979,307.721,0.87921,378.625,300.978,0.777896,391.497,293.01,0.815604,400.691,281.977,0.784073,406.821,269.105,0.862144,413.563,258.072,0.793081,420.306,244.587,0.794466,423.984,231.715,0.732577,328.976,190.647,0.872497,339.397,184.518,0.839534,352.268,183.292,0.875146,364.527,186.969,0.839053,375.561,194.938,0.876202,396.401,200.454,0.917688,406.208,199.841,0.901285,415.402,201.067,0.88975,422.145,206.584,0.819484,425.209,215.778,0.807623,381.077,211.487,0.853123,378.012,221.295,0.841805,376.786,231.102,0.913066,376.173,240.909,0.843938,360.237,248.877,0.879791,365.753,250.716,0.835396,371.27,254.394,0.914368,376.173,254.394,0.914862,381.69,254.394,0.916969,340.622,204.745,0.890919,347.978,201.067,0.972366,357.172,204.745,0.913174,362.689,211.487,0.881865,353.494,210.874,0.870807,346.139,209.649,0.909914,391.497,220.069,0.952839,399.466,216.391,0.860999,407.434,219.456,0.92136,411.725,224.972,0.943857,406.208,226.811,0.919377,397.014,224.359,0.94138,343.687,267.879,0.901622,352.881,262.975,0.908549,362.076,259.91,0.933249,367.592,263.588,0.93963,373.109,263.588,0.905823,378.625,269.105,0.85625,382.303,277.686,0.925362,376.173,282.59,0.882075,367.592,283.203,0.943062,362.076,282.59,0.937433,356.559,279.525,0.896022,348.591,274.621,0.883552,347.978,268.492,0.958477,359.624,269.105,0.894953,365.753,270.944,0.843438,371.27,272.782,0.89823,378.625,276.46,0.824877,371.27,273.395,0.897881,365.753,271.556,0.836231,359.011,269.105,0.885902,351.655,205.358,0.82943,401.917,220.069,0.941136],"hand_left_keypoints_2d":[521.435,846.377,0.651244,526.558,868.918,0.830186,535.779,896.582,0.739442,536.804,923.221,1.0176,534.755,945.762,0.950134,551.148,912.975,0.871023,549.099,946.787,0.928153,533.73,960.107,0.856091,519.386,968.303,0.856713,542.951,912.975,0.756564,539.878,946.787,0.637287,520.411,960.107,0.887973,502.993,962.156,0.871243,532.706,911.951,0.667567,528.607,938.59,0.600222,512.214,952.934,0.946885,496.845,958.057,0.856556,518.361,907.853,0.601498,512.214,933.467,0.540937,502.993,944.738,0.621017,494.796,946.787,0.741466],"hand_right_keypoints_2d":[176.566,860.874,0.71883,198.643,888.208,0.738061,210.208,918.696,0.810671,212.311,950.236,0.865716,211.259,969.159,0.71021,188.13,943.928,0.785603,196.541,977.57,0.788865,209.157,984.929,0.819588,216.516,984.929,0.907034,172.36,943.928,0.750136,178.668,980.724,0.838958,194.438,984.929,1.07872,202.849,978.621,0.706354,162.899,939.722,0.838518,169.206,971.262,0.890653,183.925,977.57,0.821722,193.387,974.416,0.549536,157.642,928.158,0.883011,160.796,956.544,0.852623,173.412,964.954,0.602694,180.771,962.852,0.455725],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} \ No newline at end of file diff --git a/assets/data/vitonhd/test_pairs.txt 
b/assets/data/vitonhd/test_pairs.txt new file mode 100644 index 0000000..d52b9da --- /dev/null +++ b/assets/data/vitonhd/test_pairs.txt @@ -0,0 +1,2 @@ +12419_00.jpg 01944_00.jpg +03191_00.jpg 00349_00.jpg diff --git a/datasets/__pycache__/dresscode.cpython-310.pyc b/datasets/__pycache__/dresscode.cpython-310.pyc new file mode 100644 index 0000000..7e8ae80 Binary files /dev/null and b/datasets/__pycache__/dresscode.cpython-310.pyc differ diff --git a/datasets/__pycache__/preview/PIL/asd.png b/datasets/__pycache__/preview/PIL/asd.png new file mode 100644 index 0000000..f41becf Binary files /dev/null and b/datasets/__pycache__/preview/PIL/asd.png differ diff --git a/datasets/__pycache__/preview/PIL/fake_img.png b/datasets/__pycache__/preview/PIL/fake_img.png new file mode 100644 index 0000000..2fc1023 Binary files /dev/null and b/datasets/__pycache__/preview/PIL/fake_img.png differ diff --git a/datasets/__pycache__/preview/Pytorch/asd.png b/datasets/__pycache__/preview/Pytorch/asd.png new file mode 100644 index 0000000..f41becf Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/asd.png differ diff --git a/datasets/__pycache__/preview/Pytorch/face.png b/datasets/__pycache__/preview/Pytorch/face.png new file mode 100644 index 0000000..99b0369 Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/face.png differ diff --git a/datasets/__pycache__/preview/Pytorch/final_img.png b/datasets/__pycache__/preview/Pytorch/final_img.png new file mode 100644 index 0000000..b236274 Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/final_img.png differ diff --git a/datasets/__pycache__/preview/Pytorch/generated_body.png b/datasets/__pycache__/preview/Pytorch/generated_body.png new file mode 100644 index 0000000..d43efd5 Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/generated_body.png differ diff --git a/datasets/__pycache__/preview/Pytorch/gt_img.png b/datasets/__pycache__/preview/Pytorch/gt_img.png new file mode 100644 index 0000000..b872bed Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/gt_img.png differ diff --git a/datasets/__pycache__/preview/Pytorch/im_parse.png b/datasets/__pycache__/preview/Pytorch/im_parse.png new file mode 100644 index 0000000..b4cef90 Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/im_parse.png differ diff --git a/datasets/__pycache__/preview/Pytorch/label_map.png b/datasets/__pycache__/preview/Pytorch/label_map.png new file mode 100644 index 0000000..5d987e5 Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/label_map.png differ diff --git a/datasets/__pycache__/preview/Pytorch/seg_head.png b/datasets/__pycache__/preview/Pytorch/seg_head.png new file mode 100644 index 0000000..8375fff Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/seg_head.png differ diff --git a/datasets/__pycache__/preview/Pytorch/true_head.png b/datasets/__pycache__/preview/Pytorch/true_head.png new file mode 100644 index 0000000..8446ccd Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/true_head.png differ diff --git a/datasets/__pycache__/preview/Pytorch/true_parts.png b/datasets/__pycache__/preview/Pytorch/true_parts.png new file mode 100644 index 0000000..5eb03fe Binary files /dev/null and b/datasets/__pycache__/preview/Pytorch/true_parts.png differ diff --git a/datasets/__pycache__/vitonhd.cpython-310.pyc b/datasets/__pycache__/vitonhd.cpython-310.pyc new file mode 100644 index 0000000..763e0e7 Binary files /dev/null and 
b/datasets/__pycache__/vitonhd.cpython-310.pyc differ diff --git a/datasets/dresscode.py b/datasets/dresscode.py new file mode 100644 index 0000000..caacf3d --- /dev/null +++ b/datasets/dresscode.py @@ -0,0 +1,434 @@ +import os +import random +from typing import List, Tuple + +import json +import numpy as np +import cv2 +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from torchvision.ops import masks_to_boxes +from PIL import Image, ImageDraw, ImageOps + +from utils.labelmap import label_map +from utils.posemap import kpoint_to_heatmap + + +class Dataset(data.Dataset): + def __init__(self, + dataroot_path: str, + phase: str, + tokenizer, + radius=5, + caption_folder='fine_captions.json', + coarse_caption_folder='coarse_captions.json', + sketch_threshold_range: Tuple[int, int] = (20, 127), + order: str = 'paired', + outputlist: Tuple[str] = ('c_name', 'im_name', 'image', 'im_cloth', 'shape', 'pose_map', + 'parse_array', 'im_mask', 'inpaint_mask', 'parse_mask_total', + 'im_sketch', 'captions', 'captions_uncond', + 'original_captions', 'category', 'stitch_label'), + category: Tuple[str] = ('dresses', 'upper_body', 'lower_body'), + size: Tuple[int, int] = (256, 192), + uncond_fraction: float = 0.0, + use_coarse_captions: bool = False, + generated_images_path: str = None, + balance_category: bool = False, + num_elements: int = None + ): + """ + Initialize the PyTroch Dataset Class + :param dataroot_path: dataset root folder + :type dataroot_path: string + :param phase: phase (train | test) + :type phase: string + :param order: setting (paired | unpaired) + :type order: string + :param category: clothing category (upper_body | lower_body | dresses) + :type category: list(str) + :param size: image size (height, width) + :type size: tuple(int) + """ + super(Dataset, self).__init__() + self.dataroot = dataroot_path + self.phase = phase + self.caption_folder = caption_folder + self.sketch_threshold_range = sketch_threshold_range + self.category = category + self.outputlist = outputlist + self.height = size[0] + self.width = size[1] + self.radius = radius + self.tokenizer = tokenizer + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + self.transform2D = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + self.order = order + self.uncond_fraction = uncond_fraction + self.use_coarse_captions = use_coarse_captions + self.generated_images_path = generated_images_path + self.balance_category = balance_category + self.num_elements = num_elements + + im_names = [] + c_names = [] + dataroot_names = [] + + possible_outputs = ['c_name', 'im_name', 'cloth', 'image', 'im_cloth', 'shape', 'im_head', 'im_pose', + 'pose_map', 'parse_array', 'dense_labels', 'dense_uv', 'skeleton', + 'im_mask', 'inpaint_mask', 'parse_mask_total', 'cloth_sketch', 'im_sketch', 'captions', + 'captions_uncond', 'original_captions', 'category', 'hands', 'parse_head_2', 'stitch_label'] + + assert all(x in possible_outputs for x in outputlist) + + # Load Captions + with open(os.path.join(self.dataroot, self.caption_folder)) as f: + self.captions_dict = json.load(f) + self.captions_dict = {k: v for k, v in self.captions_dict.items() if len(v) >= 3} + if use_coarse_captions: + with open(os.path.join(self.dataroot, coarse_caption_folder)) as f: + self.captions_dict.update(json.load(f)) + + annotated_elements = [k for k, _ in self.captions_dict.items()] + + for c in category: + assert 
c in ['dresses', 'upper_body', 'lower_body'] + + dataroot = os.path.join(self.dataroot, c) + if phase == 'train': + filename = os.path.join(dataroot, f"{phase}_pairs.txt") + else: + filename = os.path.join(dataroot, f"{phase}_pairs_{order}.txt") + + with open(filename, 'r') as f: + for line in f.readlines(): + im_name, c_name = line.strip().split() + if c_name.split('_')[0] not in self.captions_dict: + continue + + im_names.append(im_name) + c_names.append(c_name) + dataroot_names.append(dataroot) + + self.im_names = im_names + self.c_names = c_names + self.dataroot_names = dataroot_names + + def __getitem__(self, index): + """ + For each index return the corresponding sample in the dataset + :param index: data index + :type index: int + :return: dict containing dataset samples + :rtype: dict + """ + + if self.balance_category: + assert self.phase == 'train' and self.order == 'paired' + current_category = random.choice(self.category) + chosen_droot = random.choice( + [droot for droot in self.dataroot_names if droot.split('/')[-1] == current_category]) + index = self.dataroot_names.index(chosen_droot) + + c_name = self.c_names[index] + im_name = self.im_names[index] + dataroot = self.dataroot_names[index] + + sketch_threshold = random.randint(self.sketch_threshold_range[0], self.sketch_threshold_range[1]) + + if "captions" in self.outputlist or "original_captions" in self.outputlist: + captions = self.captions_dict[c_name.split('_')[0]] + # if train randomly shuffle captions if there are multiple, else concatenate with comma + if self.phase == 'train': + random.shuffle(captions) + captions = ", ".join(captions) + + # randomly drop captions according to uncond_fraction + if self.uncond_fraction > 0: + captions = "" if random.random() < self.uncond_fraction else captions + + original_captions = captions + + if "captions" in self.outputlist: + cond_input = self.tokenizer([captions], max_length=self.tokenizer.model_max_length, padding="max_length", + truncation=True, return_tensors="pt").input_ids + cond_input = cond_input.squeeze(0) + max_length = cond_input.shape[-1] + uncond_input = self.tokenizer( + [""], padding="max_length", max_length=max_length, return_tensors="pt" + ).input_ids.squeeze(0) + captions = cond_input + captions_uncond = uncond_input + + if "image" in self.outputlist or "im_head" in self.outputlist or "im_cloth" in self.outputlist: + image = Image.open(os.path.join(dataroot, 'images', im_name)) + + image = image.resize((self.width, self.height)) + image = self.transform(image) # [-1,1] + + if "im_sketch" in self.outputlist: + + if "unpaired" == self.order and self.phase == 'test': # Upper of multigarment is the same of unpaired + im_sketch = Image.open(os.path.join(dataroot, 'im_sketch_unpaired', + f'{im_name.replace(".jpg", "")}_{c_name.replace(".jpg", ".png")}')) + else: + im_sketch = Image.open(os.path.join(dataroot, 'im_sketch', c_name.replace(".jpg", ".png"))) + + im_sketch = im_sketch.resize((self.width, self.height)) + im_sketch = ImageOps.invert(im_sketch) + # threshold grayscale pil image + im_sketch = im_sketch.point(lambda p: 255 if p > sketch_threshold else 0) + # im_sketch = im_sketch.convert("RGB") + im_sketch = transforms.functional.to_tensor(im_sketch) # [-1,1] + im_sketch = 1 - im_sketch + + if "im_pose" in self.outputlist or "parser_mask" in self.outputlist or "im_mask" in self.outputlist or "parse_mask_total" in self.outputlist or "parse_array" in self.outputlist or "pose_map" in self.outputlist or "parse_array" in self.outputlist or "shape" in 
self.outputlist or "im_head" in self.outputlist: + # Label Map + parse_name = im_name.replace('_0.jpg', '_4.png') + im_parse = Image.open(os.path.join(dataroot, 'label_maps', parse_name)) + im_parse = im_parse.resize((self.width, self.height), Image.NEAREST) + parse_array = np.array(im_parse) + + parse_shape = (parse_array > 0).astype(np.float32) + + parse_head = (parse_array == 1).astype(np.float32) + \ + (parse_array == 2).astype(np.float32) + \ + (parse_array == 3).astype(np.float32) + \ + (parse_array == 11).astype(np.float32) + + parser_mask_fixed = (parse_array == label_map["hair"]).astype(np.float32) + \ + (parse_array == label_map["left_shoe"]).astype(np.float32) + \ + (parse_array == label_map["right_shoe"]).astype(np.float32) + \ + (parse_array == label_map["hat"]).astype(np.float32) + \ + (parse_array == label_map["sunglasses"]).astype(np.float32) + \ + (parse_array == label_map["scarf"]).astype(np.float32) + \ + (parse_array == label_map["bag"]).astype(np.float32) + + parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32) + + arms = (parse_array == 14).astype(np.float32) + (parse_array == 15).astype(np.float32) + + category = dataroot.split('/')[-1] + if dataroot.split('/')[-1] == 'dresses': + label_cat = 7 + parse_cloth = (parse_array == 7).astype(np.float32) + parse_mask = (parse_array == 7).astype(np.float32) + \ + (parse_array == 12).astype(np.float32) + \ + (parse_array == 13).astype(np.float32) + parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed)) + + elif dataroot.split('/')[-1] == 'upper_body': + label_cat = 4 + parse_cloth = (parse_array == 4).astype(np.float32) + parse_mask = (parse_array == 4).astype(np.float32) + + parser_mask_fixed += (parse_array == label_map["skirt"]).astype(np.float32) + \ + (parse_array == label_map["pants"]).astype(np.float32) + + parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed)) + elif dataroot.split('/')[-1] == 'lower_body': + label_cat = 6 + parse_cloth = (parse_array == 6).astype(np.float32) + parse_mask = (parse_array == 6).astype(np.float32) + \ + (parse_array == 12).astype(np.float32) + \ + (parse_array == 13).astype(np.float32) + + parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \ + (parse_array == 14).astype(np.float32) + \ + (parse_array == 15).astype(np.float32) + parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed)) + else: + raise NotImplementedError + + parse_head = torch.from_numpy(parse_head) # [0,1] + parse_cloth = torch.from_numpy(parse_cloth) # [0,1] + parse_mask = torch.from_numpy(parse_mask) # [0,1] + parser_mask_fixed = torch.from_numpy(parser_mask_fixed) + parser_mask_changeable = torch.from_numpy(parser_mask_changeable) + + # dilation + parse_without_cloth = np.logical_and(parse_shape, np.logical_not(parse_mask)) + parse_mask = parse_mask.cpu().numpy() + + if "im_head" in self.outputlist: + # Masked cloth + im_head = image * parse_head - (1 - parse_head) + if "im_cloth" in self.outputlist: + im_cloth = image * parse_cloth + (1 - parse_cloth) + + # Shape + parse_shape = Image.fromarray((parse_shape * 255).astype(np.uint8)) + parse_shape = parse_shape.resize((self.width // 16, self.height // 16), Image.BILINEAR) + parse_shape = parse_shape.resize((self.width, self.height), Image.BILINEAR) + shape = self.transform2D(parse_shape) # [-1,1] + + # Load pose points + pose_name = im_name.replace('_0.jpg', '_2.json') + with open(os.path.join(dataroot, 
'keypoints', pose_name), 'r') as f: + pose_label = json.load(f) + pose_data = pose_label['keypoints'] + pose_data = np.array(pose_data) + pose_data = pose_data.reshape((-1, 4)) + + point_num = pose_data.shape[0] + pose_map = torch.zeros(point_num, self.height, self.width) + r = self.radius * (self.height / 512.0) + im_pose = Image.new('L', (self.width, self.height)) + pose_draw = ImageDraw.Draw(im_pose) + neck = Image.new('L', (self.width, self.height)) + neck_draw = ImageDraw.Draw(neck) + for i in range(point_num): + one_map = Image.new('L', (self.width, self.height)) + draw = ImageDraw.Draw(one_map) + point_x = np.multiply(pose_data[i, 0], self.width / 384.0) + point_y = np.multiply(pose_data[i, 1], self.height / 512.0) + if point_x > 1 and point_y > 1: + draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white') + pose_draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white') + if i == 2 or i == 5: + neck_draw.ellipse((point_x - r * 4, point_y - r * 4, point_x + r * 4, point_y + r * 4), 'white', + 'white') + one_map = self.transform2D(one_map) + pose_map[i] = one_map[0] + + d = [] + for pose_d in pose_data: + ux = pose_d[0] / 384.0 + uy = pose_d[1] / 512.0 + + # scale posemap points + px = ux * self.width + py = uy * self.height + + d.append(kpoint_to_heatmap(np.array([px, py]), (self.height, self.width), 9)) + + pose_map = torch.stack(d) + + # just for visualization + im_pose = self.transform2D(im_pose) + + im_arms = Image.new('L', (self.width, self.height)) + arms_draw = ImageDraw.Draw(im_arms) + if dataroot.split('/')[-1] == 'dresses' or dataroot.split('/')[-1] == 'upper_body' or dataroot.split('/')[ + -1] == 'lower_body': + with open(os.path.join(dataroot, 'keypoints', pose_name), 'r') as f: + data = json.load(f) + shoulder_right = np.multiply(tuple(data['keypoints'][2][:2]), self.height / 512.0) + shoulder_left = np.multiply(tuple(data['keypoints'][5][:2]), self.height / 512.0) + elbow_right = np.multiply(tuple(data['keypoints'][3][:2]), self.height / 512.0) + elbow_left = np.multiply(tuple(data['keypoints'][6][:2]), self.height / 512.0) + wrist_right = np.multiply(tuple(data['keypoints'][4][:2]), self.height / 512.0) + wrist_left = np.multiply(tuple(data['keypoints'][7][:2]), self.height / 512.0) + if wrist_right[0] <= 1. and wrist_right[1] <= 1.: + if elbow_right[0] <= 1. and elbow_right[1] <= 1.: + arms_draw.line( + np.concatenate((wrist_left, elbow_left, shoulder_left, shoulder_right)).astype( + np.uint16).tolist(), 'white', 45, 'curve') + else: + arms_draw.line(np.concatenate( + (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right)).astype( + np.uint16).tolist(), 'white', 45, 'curve') + elif wrist_left[0] <= 1. and wrist_left[1] <= 1.: + if elbow_left[0] <= 1. 
and elbow_left[1] <= 1.: + arms_draw.line( + np.concatenate((shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( + np.uint16).tolist(), 'white', 45, 'curve') + else: + arms_draw.line(np.concatenate( + (elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( + np.uint16).tolist(), 'white', 45, 'curve') + else: + arms_draw.line(np.concatenate( + (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( + np.uint16).tolist(), 'white', 45, 'curve') + + hands = np.logical_and(np.logical_not(im_arms), arms) + + if dataroot.split('/')[-1] == 'dresses' or dataroot.split('/')[-1] == 'upper_body': + parse_mask += im_arms + parser_mask_fixed += hands + + # delete neck + parse_head_2 = torch.clone(parse_head) + if dataroot.split('/')[-1] == 'dresses' or dataroot.split('/')[-1] == 'upper_body': + with open(os.path.join(dataroot, 'keypoints', pose_name), 'r') as f: + data = json.load(f) + points = [] + points.append(np.multiply(tuple(data['keypoints'][2][:2]), self.height / 512.0)) + points.append(np.multiply(tuple(data['keypoints'][5][:2]), self.height / 512.0)) + x_coords, y_coords = zip(*points) + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m, c = np.linalg.lstsq(A, y_coords, rcond=None)[0] + for i in range(parse_array.shape[1]): + y = i * m + c + parse_head_2[int(y - 20 * (self.height / 512.0)):, i] = 0 + + parser_mask_fixed = np.logical_or(parser_mask_fixed, np.array(parse_head_2, dtype=np.uint16)) + parse_mask += np.logical_or(parse_mask, np.logical_and(np.array(parse_head, dtype=np.uint16), + np.logical_not( + np.array(parse_head_2, dtype=np.uint16)))) + + # tune the amount of dilation here + parse_mask = cv2.dilate(parse_mask, np.ones((5, 5), np.uint16), iterations=5) + parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask)) + parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed) + im_mask = image * parse_mask_total + inpaint_mask = 1 - parse_mask_total + + # here we have to modify the mask and get the bounding box + bboxes = masks_to_boxes(inpaint_mask.unsqueeze(0)) + bboxes = bboxes.type(torch.int32) # xmin, ymin, xmax, ymax format + xmin = bboxes[0, 0] + xmax = bboxes[0, 2] + ymin = bboxes[0, 1] + ymax = bboxes[0, 3] + + inpaint_mask[ymin:ymax + 1, xmin:xmax + 1] = torch.logical_and( + torch.ones_like(inpaint_mask[ymin:ymax + 1, xmin:xmax + 1]), + torch.logical_not(parser_mask_fixed[ymin:ymax + 1, xmin:xmax + 1])) + + inpaint_mask = inpaint_mask.unsqueeze(0) + im_mask = image * np.logical_not(inpaint_mask.repeat(3, 1, 1)) + parse_mask_total = parse_mask_total.numpy() + parse_mask_total = parse_array * parse_mask_total + parse_mask_total = torch.from_numpy(parse_mask_total) + + # randomlly drop inputs according to uncond_fraction + if "pose_map" in self.outputlist and random.random() < self.uncond_fraction: + pose_map = torch.zeros_like(pose_map) + + if "im_sketch" in self.outputlist and random.random() < self.uncond_fraction: + im_sketch = torch.zeros_like(im_sketch) + + if "stitch_label" in self.outputlist: + stitch_labelmap = Image.open(os.path.join(self.dataroot, 'test_stitchmap', im_name.replace(".jpg", ".png"))) + stitch_labelmap = transforms.ToTensor()(stitch_labelmap) * 255 + stitch_label = stitch_labelmap == 13 + + result = {} + for k in self.outputlist: + result[k] = vars()[k] + + # Output interpretation + # "c_name" -> filename of inshop cloth + # "im_name" -> filename of model with cloth + # "cloth" -> img of inshop cloth + # "image" -> img of the model with that cloth + 
# "im_cloth" -> cut cloth from the model + # "im_mask" -> black mask of the cloth in the model img + # "cloth_sketch" -> sketch of the inshop cloth + # "im_sketch" -> sketch of "im_cloth" + + return result + + def __len__(self): + if self.num_elements == None: + return len(self.c_names) + else: + return len(self.c_names[:self.num_elements]) diff --git a/datasets/vitonhd.py b/datasets/vitonhd.py new file mode 100644 index 0000000..bf4bda2 --- /dev/null +++ b/datasets/vitonhd.py @@ -0,0 +1,415 @@ +import random +import cv2 +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from PIL import Image, ImageDraw, ImageOps +import os +import numpy as np +import json +from typing import List, Tuple +# from utils.labelmap import label_map +from numpy.linalg import lstsq +from torchvision.ops import masks_to_boxes +from utils.posemap import kpoint_to_heatmap +from utils.posemap import get_coco_body25_mapping +from utils.labelmap import label_map_vitonhd as labels + + +class Dataset(torch.nn.Module): + def __init__( + self, + dataroot_path: str, + phase: str, + tokenizer, + radius=5, + caption_folder='captions.json', + sketch_threshold_range: Tuple[int, int] = (20, 127), + order: str = 'paired', + outputlist: Tuple[str] = ('c_name', 'im_name', 'image', 'im_cloth', 'shape', 'pose_map', + 'parse_array', 'im_mask', 'inpaint_mask', 'parse_mask_total', + 'im_sketch', 'captions', 'captions_uncond', 'original_captions'), + category: Tuple[str] = ('dresses', 'upper_body', 'lower_body'), + size: Tuple[int, int] = (256, 192), + uncond_fraction: float = 0.0, + ): + """ + Initialize the PyTroch Dataset Class + :param dataroot_path: dataset root folder + :type dataroot_path: string + :param phase: phase (train | test) + :type phase: string + :param order: setting (paired | unpaired) + :type order: string + :param category: clothing category (upper_body | lower_body | dresses) + :type category: list(str) + :param size: image size (height, width) + :type size: tuple(int) + """ + super(Dataset, self).__init__() + self.dataroot = dataroot_path + self.phase = phase + self.caption_folder = caption_folder + self.sketch_threshold_range = sketch_threshold_range + self.category = ('upper_body') + self.outputlist = outputlist + self.height = size[0] + self.width = size[1] + self.radius = radius + self.tokenizer = tokenizer + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + self.transform2D = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + self.order = order + self.uncond_fraction = uncond_fraction + im_names = [] + c_names = [] + dataroot_names = [] + + possible_outputs = ['c_name', 'im_name', 'image', 'im_cloth', 'shape', 'im_head', 'im_pose', + 'pose_map', 'parse_array', + 'im_mask', 'inpaint_mask', 'parse_mask_total', 'im_sketch', 'captions', + 'captions_uncond', 'original_captions', 'category'] + + assert all(x in possible_outputs for x in outputlist) + + # Load Captions + with open(os.path.join(self.dataroot, self.caption_folder)) as f: + # self.captions_dict = json.load(f)['items'] + self.captions_dict = json.load(f) + self.captions_dict = {k: v for k, v in self.captions_dict.items() if len(v) >= 3} + + annotated_elements = [k for k, _ in self.captions_dict.items()] + + dataroot = self.dataroot + if phase == 'train': + filename = os.path.join(dataroot, f"{phase}_pairs.txt") + else: + filename = os.path.join(dataroot, f"{phase}_pairs.txt") + + with 
open(filename, 'r') as f: + data_len = len(f.readlines()) + + with open(filename, 'r') as f: + for line in f.readlines(): + if phase == 'train': + im_name, _ = line.strip().split() + c_name = im_name + else: + if order == 'paired': + im_name, _ = line.strip().split() + c_name = im_name + else: + im_name, c_name = line.strip().split() + + im_names.append(im_name) + c_names.append(c_name) + dataroot_names.append(dataroot) + + self.im_names = im_names + self.c_names = c_names + self.dataroot_names = dataroot_names + + def __getitem__(self, index): + """ + For each index return the corresponding sample in the dataset + :param index: data index + :type index: int + :return: dict containing dataset samples + :rtype: dict + """ + c_name = self.c_names[index] + im_name = self.im_names[index] + dataroot = self.dataroot_names[index] + + sketch_threshold = random.randint(self.sketch_threshold_range[0], self.sketch_threshold_range[1]) + + if "captions" in self.outputlist or "original_captions" in self.outputlist: + + try: + captions = self.captions_dict[c_name.split('_')[0]] + except: + captions = [''] + try: + # take a random caption if there are multiple + if self.phase == 'train': + random.shuffle(captions) + captions = ", ".join(captions) + except: + raise ValueError( + f"Captions should contain list of strings" + ) + + if self.uncond_fraction > 0: + captions = "" if random.random() < self.uncond_fraction else captions + + original_captions = captions + + if "captions" in self.outputlist: + cond_input = self.tokenizer([captions], max_length=self.tokenizer.model_max_length, padding="max_length", + truncation=True, return_tensors="pt").input_ids + cond_input = cond_input.squeeze(0) + max_length = cond_input.shape[-1] + uncond_input = self.tokenizer( + [""], padding="max_length", max_length=max_length, return_tensors="pt" + ).input_ids.squeeze(0) + captions = cond_input + captions_uncond = uncond_input + + if "image" in self.outputlist or "im_head" in self.outputlist or "im_cloth" in self.outputlist: + # Person image + # image = Image.open(os.path.join(dataroot, 'images', im_name)) + image = Image.open(os.path.join(dataroot, self.phase, 'image', im_name)) + image = image.resize((self.width, self.height)) + image = self.transform(image) # [-1,1] + + if "im_sketch" in self.outputlist: + # Person image + # im_sketch = Image.open(os.path.join(dataroot, 'im_sketch', c_name.replace(".jpg", ".png"))) + if self.order == 'unpaired': + im_sketch = Image.open( + os.path.join(dataroot, self.phase, 'im_sketch_unpaired', + os.path.splitext(im_name)[0] + '_' + c_name.replace(".jpg", ".png"))) + elif self.order == 'paired': + im_sketch = Image.open(os.path.join(dataroot, self.phase, 'im_sketch', im_name.replace(".jpg", ".png"))) + else: + raise ValueError( + f"Order should be either paired or unpaired" + ) + + im_sketch = im_sketch.resize((self.width, self.height)) + im_sketch = ImageOps.invert(im_sketch) + # threshold grayscale pil image + im_sketch = im_sketch.point(lambda p: 255 if p > sketch_threshold else 0) + # im_sketch = im_sketch.convert("RGB") + im_sketch = transforms.functional.to_tensor(im_sketch) # [-1,1] + im_sketch = 1 - im_sketch + + + if "im_pose" in self.outputlist or "parser_mask" in self.outputlist or "im_mask" in self.outputlist or "parse_mask_total" in self.outputlist or "parse_array" in self.outputlist or "pose_map" in self.outputlist or "parse_array" in self.outputlist or "shape" in self.outputlist or "im_head" in self.outputlist: + # Label Map + # parse_name = im_name.replace('_0.jpg', 
'_4.png') + parse_name = im_name.replace('.jpg', '.png') + im_parse = Image.open(os.path.join(dataroot, self.phase, 'image-parse-v3', parse_name)) + im_parse = im_parse.resize((self.width, self.height), Image.NEAREST) + im_parse_final = transforms.ToTensor()(im_parse) * 255 + parse_array = np.array(im_parse) + + parse_shape = (parse_array > 0).astype(np.float32) + + parse_head = (parse_array == 1).astype(np.float32) + \ + (parse_array == 2).astype(np.float32) + \ + (parse_array == 4).astype(np.float32) + \ + (parse_array == 13).astype(np.float32) + + parser_mask_fixed = (parse_array == 1).astype(np.float32) + \ + (parse_array == 2).astype(np.float32) + \ + (parse_array == 18).astype(np.float32) + \ + (parse_array == 19).astype(np.float32) + + # parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32) + parser_mask_changeable = (parse_array == 0).astype(np.float32) + + arms = (parse_array == 14).astype(np.float32) + (parse_array == 15).astype(np.float32) + + parse_cloth = (parse_array == 5).astype(np.float32) + \ + (parse_array == 6).astype(np.float32) + \ + (parse_array == 7).astype(np.float32) + parse_mask = (parse_array == 5).astype(np.float32) + \ + (parse_array == 6).astype(np.float32) + \ + (parse_array == 7).astype(np.float32) + + parser_mask_fixed = parser_mask_fixed + (parse_array == 9).astype(np.float32) + \ + (parse_array == 12).astype(np.float32) # the lower body is fixed + + parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed)) + + parse_head = torch.from_numpy(parse_head) # [0,1] + parse_cloth = torch.from_numpy(parse_cloth) # [0,1] + parse_mask = torch.from_numpy(parse_mask) # [0,1] + parser_mask_fixed = torch.from_numpy(parser_mask_fixed) + parser_mask_changeable = torch.from_numpy(parser_mask_changeable) + + # dilation + parse_without_cloth = np.logical_and(parse_shape, np.logical_not(parse_mask)) + parse_mask = parse_mask.cpu().numpy() + + if "im_head" in self.outputlist: + # Masked cloth + im_head = image * parse_head - (1 - parse_head) + if "im_cloth" in self.outputlist: + im_cloth = image * parse_cloth + (1 - parse_cloth) + + # Shape + parse_shape = Image.fromarray((parse_shape * 255).astype(np.uint8)) + parse_shape = parse_shape.resize((self.width // 16, self.height // 16), Image.BILINEAR) + parse_shape = parse_shape.resize((self.width, self.height), Image.BILINEAR) + shape = self.transform2D(parse_shape) # [-1,1] + + # Load pose points + pose_name = im_name.replace('.jpg', '_keypoints.json') + with open(os.path.join(dataroot, self.phase, 'openpose_json', pose_name), 'r') as f: + pose_label = json.load(f) + pose_data = pose_label['people'][0]['pose_keypoints_2d'] + pose_data = np.array(pose_data) + pose_data = pose_data.reshape((-1, 3))[:, :2] + + # rescale keypoints on the base of height and width + pose_data[:, 0] = pose_data[:, 0] * (self.width / 768) + pose_data[:, 1] = pose_data[:, 1] * (self.height / 1024) + + pose_mapping = get_coco_body25_mapping() + + point_num = len(pose_mapping) + + pose_map = torch.zeros(point_num, self.height, self.width) + r = self.radius * (self.height / 512.0) + im_pose = Image.new('L', (self.width, self.height)) + pose_draw = ImageDraw.Draw(im_pose) + neck = Image.new('L', (self.width, self.height)) + neck_draw = ImageDraw.Draw(neck) + for i in range(point_num): + one_map = Image.new('L', (self.width, self.height)) + draw = ImageDraw.Draw(one_map) + point_x = np.multiply(pose_data[pose_mapping[i], 0], 1) + point_y = np.multiply(pose_data[pose_mapping[i], 1], 1) + + if 
point_x > 1 and point_y > 1: + draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white') + pose_draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white') + if i == 2 or i == 5: + neck_draw.ellipse((point_x - r * 4, point_y - r * 4, point_x + r * 4, point_y + r * 4), 'white', + 'white') + one_map = self.transform2D(one_map) + pose_map[i] = one_map[0] + + d = [] + + for idx in range(point_num): + ux = pose_data[pose_mapping[idx], 0] # / (192) + uy = (pose_data[pose_mapping[idx], 1]) # / (256) + + # scale posemap points + px = ux # * self.width + py = uy # * self.height + + d.append(kpoint_to_heatmap(np.array([px, py]), (self.height, self.width), 9)) + + pose_map = torch.stack(d) + + # just for visualization + im_pose = self.transform2D(im_pose) + + im_arms = Image.new('L', (self.width, self.height)) + arms_draw = ImageDraw.Draw(im_arms) + + # do in any case because i have only upperbody + with open(os.path.join(dataroot, self.phase, 'openpose_json', pose_name), 'r') as f: + data = json.load(f) + data = data['people'][0]['pose_keypoints_2d'] + data = np.array(data) + data = data.reshape((-1, 3))[:, :2] + + # rescale keypoints on the base of height and width + data[:, 0] = data[:, 0] * (self.width / 768) + data[:, 1] = data[:, 1] * (self.height / 1024) + + shoulder_right = np.multiply(tuple(data[pose_mapping[2]]), 1) + shoulder_left = np.multiply(tuple(data[pose_mapping[5]]), 1) + elbow_right = np.multiply(tuple(data[pose_mapping[3]]), 1) + elbow_left = np.multiply(tuple(data[pose_mapping[6]]), 1) + wrist_right = np.multiply(tuple(data[pose_mapping[4]]), 1) + wrist_left = np.multiply(tuple(data[pose_mapping[7]]), 1) + + ARM_LINE_WIDTH = int(90 / 512 * self.height) + if wrist_right[0] <= 1. and wrist_right[1] <= 1.: + if elbow_right[0] <= 1. and elbow_right[1] <= 1.: + arms_draw.line( + np.concatenate((wrist_left, elbow_left, shoulder_left, shoulder_right)).astype( + np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') + else: + arms_draw.line(np.concatenate( + (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right)).astype( + np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') + elif wrist_left[0] <= 1. and wrist_left[1] <= 1.: + if elbow_left[0] <= 1. 
and elbow_left[1] <= 1.: + arms_draw.line( + np.concatenate((shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( + np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') + else: + arms_draw.line(np.concatenate( + (elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( + np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') + else: + arms_draw.line(np.concatenate( + (wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right)).astype( + np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve') + + hands = np.logical_and(np.logical_not(im_arms), arms) + parse_mask += im_arms + parser_mask_fixed += hands + + # delete neck + parse_head_2 = torch.clone(parse_head) + + parser_mask_fixed = np.logical_or(parser_mask_fixed, np.array(parse_head_2, dtype=np.uint16)) + parse_mask += np.logical_or(parse_mask, np.logical_and(np.array(parse_head, dtype=np.uint16), + np.logical_not( + np.array(parse_head_2, dtype=np.uint16)))) + + parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask)) + parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed) + # im_mask = image * parse_mask_total + inpaint_mask = 1 - parse_mask_total + + # here we have to modify the mask and get the bounding box + bboxes = masks_to_boxes(inpaint_mask.unsqueeze(0)) + bboxes = bboxes.type(torch.int32) # xmin, ymin, xmax, ymax format + xmin = bboxes[0, 0] + xmax = bboxes[0, 2] + ymin = bboxes[0, 1] + ymax = bboxes[0, 3] + + inpaint_mask[ymin:ymax + 1, xmin:xmax + 1] = torch.logical_and( + torch.ones_like(inpaint_mask[ymin:ymax + 1, xmin:xmax + 1]), + torch.logical_not(parser_mask_fixed[ymin:ymax + 1, xmin:xmax + 1])) + + inpaint_mask = inpaint_mask.unsqueeze(0) + im_mask = image * np.logical_not(inpaint_mask.repeat(3, 1, 1)) + parse_mask_total = parse_mask_total.numpy() + parse_mask_total = parse_array * parse_mask_total + parse_mask_total = torch.from_numpy(parse_mask_total) + + if "pose_map" in self.outputlist and torch.rand(1) < self.uncond_fraction: + pose_map = torch.zeros_like(pose_map) + + if "im_sketch" in self.outputlist and torch.rand(1) < self.uncond_fraction: + im_sketch = torch.zeros_like(im_sketch) + + result = {} + for k in self.outputlist: + result[k] = vars()[k] + + result['im_parse'] = im_parse_final + result['hands'] = torch.from_numpy(hands) + + # Output interpretation + # "c_name" -> filename of inshop cloth + # "im_name" -> filename of model with cloth + # "cloth" -> img of inshop cloth + # "image" -> img of the model with that cloth + # "im_cloth" -> cut cloth from the model + # "im_mask" -> black mask of the cloth in the model img + # "cloth_sketch" -> sketch of the inshop cloth + # "im_sketch" -> sketch of "im_cloth" + # inpaint_mask -> bb of the model img where the cloth is + + return result + + def __len__(self): + return len(self.c_names) diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..de26f3e --- /dev/null +++ b/eval.py @@ -0,0 +1,157 @@ +import json +import os + +# external libraries +import accelerate +import torch +import torch.utils.checkpoint +import torch.utils.checkpoint +from accelerate import Accelerator +from accelerate.logging import get_logger +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from transformers import CLIPTextModel, CLIPTokenizer + +# custom imports +from datasets.dresscode import Dataset as Dataset +from datasets.vitonhd import Dataset as Dataset_viton +from 
pipes.sketch_posemap_inpaint_pipe import StableDiffusionSketchPosemapInpaintPipeline as ValPipe +from pipes.sketch_posemap_inpaint_pipe_disentangled import StableDiffusionSketchPosemapInpaintPipeline as ValPipeDisentangled +from utils.image_from_pipe import generate_images_from_inpaint_sketch_posemap_pipe +from utils.set_seeds import set_seed +from utils.arg_parser import parse_args + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") +os.environ["TOKENIZERS_PARALLELISM"] = "true" +os.environ["WANDB_START_METHOD"] = "thread" + + +def main() -> None: + args = parse_args() + accelerator = Accelerator( + mixed_precision=args.mixed_precision, + ) + device = accelerator.device + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Load scheduler, tokenizer and models. + val_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + val_scheduler.set_timesteps(50, device=device) + + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + + unet = torch.hub.load(dataset=args.dataset, repo_or_dir='aimagelab/multimodal-garment-designer', source='github', model='mgd', pretrained=True) + + # Freeze vae and text_encoder + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + # Enable memory efficient attention if requested + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.category: + category = [args.category] + else: + category = ['dresses', 'upper_body', 'lower_body'] + + if args.dataset == "dresscode": + test_dataset = Dataset( + dataroot_path=args.dataset_path, + phase='test', + order=args.test_order, + radius=5, + sketch_threshold_range=(20, 20), + tokenizer=tokenizer, + category=category, + size=(512, 384) + ) + elif args.dataset == "vitonhd": + test_dataset = Dataset_viton( + dataroot_path=args.dataset_path, + phase='test', + order=args.test_order, + sketch_threshold_range=(20, 20), + radius=5, + tokenizer=tokenizer, + category=['upper_body'], + size=(512, 384), + ) + else: + raise NotImplementedError + + test_dataloader = torch.utils.data.DataLoader( + test_dataset, + shuffle=False, + batch_size=args.batch_size, + num_workers=args.num_workers_test, + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. 
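# A minimal standalone sketch (not part of the original patch) of the checkpoint loading
# performed above through the repo's torch.hub entrypoint (hubconf.mgd). 'vitonhd' stands
# in for either supported dataset name ('dresscode' | 'vitonhd'); device and dtype handling
# is left to the caller, as done later in eval.py.

import torch

unet = torch.hub.load(
    repo_or_dir='aimagelab/multimodal-garment-designer',
    source='github',
    model='mgd',
    dataset='vitonhd',   # selects the released <dataset>.pth weights in hubconf.mgd
    pretrained=True,
)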
+ weight_dtype = torch.float32 + if args.mixed_precision == 'fp16': + weight_dtype = torch.float16 + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(device, dtype=weight_dtype) + vae.to(device, dtype=weight_dtype) + + unet.eval() + # Select fast classifier free guidance or disentagle classifier free guidance according to the disentagle parameter in args + with torch.inference_mode(): + if args.disentagle: + val_pipe = ValPipeDisentangled( + text_encoder=text_encoder, + vae=vae, + unet=unet.to(vae.dtype), + tokenizer=tokenizer, + scheduler=val_scheduler, + ).to(device) + else: + val_pipe = ValPipe( + text_encoder=text_encoder, + vae=vae, + unet=unet.to(vae.dtype), + tokenizer=tokenizer, + scheduler=val_scheduler, + ).to(device) + + val_pipe.enable_attention_slicing() + test_dataloader = accelerator.prepare(test_dataloader) + generate_images_from_inpaint_sketch_posemap_pipe( + test_order = args.test_order, + pipe = val_pipe, + test_dataloader = test_dataloader, + save_name = args.save_name, + dataset = args.dataset, + output_dir = args.output_dir, + guidance_scale = args.guidance_scale, + guidance_scale_pose = args.guidance_scale_pose, + guidance_scale_sketch = args.guidance_scale_sketch, + sketch_cond_rate = args.sketch_cond_rate, + start_cond_rate = args.start_cond_rate, + no_pose = False, + disentagle = False, + seed = args.seed, + ) + + +if __name__ == "__main__": + main() diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000..0c12ccb --- /dev/null +++ b/hubconf.py @@ -0,0 +1,23 @@ +dependencies = ['torch', 'diffusers'] +import os +import torch +from diffusers import UNet2DConditionModel + + +# mgd is the name of entrypoint +def mgd(dataset:str, pretrained: bool =True, **kwargs) -> UNet2DConditionModel: + + """ # This docstring shows up in hub.help() + MGD model + pretrained (bool): kwargs, load pretrained weights into the model + """ + + config = UNet2DConditionModel.load_config("runwayml/stable-diffusion-inpainting", subfolder="unet") + config['in_channels'] = 28 + unet = UNet2DConditionModel.from_config(config) + + if pretrained: + checkpoint = f"https://github.com/aimagelab/multimodal-garment-designer/releases/download/weights/{dataset}.pth" + unet.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True)) + + return unet diff --git a/images/1.gif b/images/1.gif new file mode 100644 index 0000000..62d63ca Binary files /dev/null and b/images/1.gif differ diff --git a/output/test_paired/images/048462_0.jpg b/output/test_paired/images/048462_0.jpg new file mode 100644 index 0000000..ab1454e Binary files /dev/null and b/output/test_paired/images/048462_0.jpg differ diff --git a/output/test_paired/images/050855_0.jpg b/output/test_paired/images/050855_0.jpg new file mode 100644 index 0000000..23f9901 Binary files /dev/null and b/output/test_paired/images/050855_0.jpg differ diff --git a/output/test_paired/images/050915_0.jpg b/output/test_paired/images/050915_0.jpg new file mode 100644 index 0000000..e4c671f Binary files /dev/null and b/output/test_paired/images/050915_0.jpg differ diff --git a/output/test_paired/images/052012_0.jpg b/output/test_paired/images/052012_0.jpg new file mode 100644 index 0000000..658b012 Binary files /dev/null and b/output/test_paired/images/052012_0.jpg differ diff --git a/output/test_unpaired/images/048462_0.jpg b/output/test_unpaired/images/048462_0.jpg new file mode 100644 index 0000000..cfdb96f Binary files /dev/null and b/output/test_unpaired/images/048462_0.jpg differ diff --git 
a/output/test_unpaired/images/050855_0.jpg b/output/test_unpaired/images/050855_0.jpg new file mode 100644 index 0000000..2d39bc1 Binary files /dev/null and b/output/test_unpaired/images/050855_0.jpg differ diff --git a/output/test_unpaired/images/050915_0.jpg b/output/test_unpaired/images/050915_0.jpg new file mode 100644 index 0000000..ef8d49b Binary files /dev/null and b/output/test_unpaired/images/050915_0.jpg differ diff --git a/output/test_unpaired/images/052012_0.jpg b/output/test_unpaired/images/052012_0.jpg new file mode 100644 index 0000000..a217196 Binary files /dev/null and b/output/test_unpaired/images/052012_0.jpg differ diff --git a/output_vitonhd/test_unpaired/images/03191_00.jpg b/output_vitonhd/test_unpaired/images/03191_00.jpg new file mode 100644 index 0000000..d15f64e Binary files /dev/null and b/output_vitonhd/test_unpaired/images/03191_00.jpg differ diff --git a/output_vitonhd/test_unpaired/images/12419_00.jpg b/output_vitonhd/test_unpaired/images/12419_00.jpg new file mode 100644 index 0000000..e541f93 Binary files /dev/null and b/output_vitonhd/test_unpaired/images/12419_00.jpg differ diff --git a/pipes/__init__.py b/pipes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pipes/sketch_posemap_inpaint_pipe.py b/pipes/sketch_posemap_inpaint_pipe.py new file mode 100644 index 0000000..e5ad048 --- /dev/null +++ b/pipes/sketch_posemap_inpaint_pipe.py @@ -0,0 +1,638 @@ +import inspect +from typing import Callable, List, Optional, Union + +from packaging import version +import PIL +import numpy as np +import torch +import torchvision +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers.utils import is_accelerate_available +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image + + +class StableDiffusionSketchPosemapInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text and posemap -guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
+ safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker=None, + feature_extractor=None, + requires_safety_checker: bool = False, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
+ ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + masked_image_latents = [ + self.vae.encode(masked_image[i: i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) + masked_image_latents = 0.18215 * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + pose_map: torch.FloatTensor, + sketch: torch.FloatTensor, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + sketch_cond_rate: float = 1.0, + start_cond_rate: float = 0, + no_pose: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess mask, image and posemap + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + pose_map = torch.nn.functional.interpolate( + pose_map, size=(pose_map.shape[2] // 8, pose_map.shape[3] // 8), mode="bilinear" + ) + if no_pose: + pose_map = torch.zeros_like(pose_map) + + sketch = torchvision.transforms.functional.resize( + sketch, size=(sketch.shape[2] // 8, sketch.shape[3] // 8), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + antialias=True) + sketch = sketch + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5a. 
Compute the number of steps to run sketch conditioning + # sketch_conditioning_steps = (1 - sketch_cond_rate) * num_inference_steps + start_cond_step = int(num_inference_steps * start_cond_rate) + + sketch_start = start_cond_step + sketch_end = sketch_cond_rate * num_inference_steps + start_cond_step + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + text_embeddings.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7a. Prepare pose map latent variables + pose_map = torch.cat([torch.zeros_like(pose_map), pose_map]) if do_classifier_free_guidance else pose_map + sketch = torch.cat([torch.zeros_like(sketch), sketch]) if do_classifier_free_guidance else sketch + + # 8. Check that sizes of mask, masked image and latents match + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + num_channels_pose_map = pose_map.shape[1] + num_channels_sketch = sketch.shape[1] + + if num_channels_latents + num_channels_mask + num_channels_masked_image + num_channels_pose_map + num_channels_sketch != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # 10a. 
Sketch conditioning + if i < sketch_start or i > sketch_end: + local_sketch = torch.zeros_like(sketch) + else: + local_sketch = sketch + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat( + [latent_model_input, mask, masked_image_latents, pose_map.to(mask.dtype), local_sketch.to(mask.dtype)], + dim=1) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.to( + self.vae.dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 11. Post-processing + image = self.decode_latents(latents) + + # 13. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/pipes/sketch_posemap_inpaint_pipe_disentangled.py b/pipes/sketch_posemap_inpaint_pipe_disentangled.py new file mode 100644 index 0000000..956362f --- /dev/null +++ b/pipes/sketch_posemap_inpaint_pipe_disentangled.py @@ -0,0 +1,638 @@ +import inspect +from typing import Callable, List, Optional, Union + +import PIL +import torch +import torchvision + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image + + +class StableDiffusionSketchPosemapInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text and posemap -guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker = None, + feature_extractor = None, + requires_safety_checker: bool = False, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
+ ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, uncond_embeddings, uncond_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + masked_image_latents = [ + self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) + masked_image_latents = 0.18215 * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 4) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 4) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + pose_map: torch.FloatTensor, + sketch: torch.FloatTensor, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + guidance_scale_pose: float = 7.5, + guidance_scale_sketch: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + sketch_cond_rate: float = 1.0, + start_cond_rate: float = 0, + no_pose: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
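+            pose_map (`torch.FloatTensor`):
+                Heatmap tensor of the target body keypoints. It is downsampled by a factor of 8 and concatenated
+                to the U-Net input along the channel dimension (replaced with zeros if `no_pose` is True).
+            sketch (`torch.FloatTensor`):
+                Garment sketch tensor. It is downsampled by a factor of 8 and concatenated to the U-Net input,
+                but only within the conditioning window defined by `start_cond_rate` and `sketch_cond_rate`;
+                zeros are used outside that window.
+            guidance_scale_pose (`float`, *optional*, defaults to 7.5):
+                Guidance weight applied to the pose-conditioned branch of the disentangled classifier-free guidance.
+            guidance_scale_sketch (`float`, *optional*, defaults to 7.5):
+                Guidance weight applied to the sketch-conditioned branch of the disentangled classifier-free guidance.
+            sketch_cond_rate (`float`, *optional*, defaults to 1.0):
+                Fraction of the denoising steps during which the sketch conditioning is active.
+            start_cond_rate (`float`, *optional*, defaults to 0):
+                Fraction of the denoising steps to skip before the sketch conditioning starts.
+            no_pose (`bool`, *optional*, defaults to `False`):
+                If `True`, the pose map is replaced with zeros.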
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is
+            `None`, since this pipeline does not run the safety checker on its outputs.
+        """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs
+        self.check_inputs(prompt, height, width, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance. This pipeline always builds the four
+        # guidance branches (unconditional, text, pose, sketch), so guidance is always enabled.
+        do_classifier_free_guidance = True
+
+        # 3. Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+        )
+        # pose_map_vis = pose_map.sum(dim=1).unsqueeze(1).clamp(0,1)
+        # plt.imsave("Pose.jpg", pose_map_vis[0].repeat(3,1,1).permute(1,2,0).cpu().numpy())
+
+        # 4. 
Preprocess mask, image and posemap + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + pose_map = torch.nn.functional.interpolate( + pose_map, size=(pose_map.shape[2] // 8, pose_map.shape[3] // 8), mode="bilinear" + ) + if no_pose: + pose_map = torch.zeros_like(pose_map) + + # plt.imsave("Sketch.jpg", sketch[0].repeat(3,1,1).permute(1,2,0).cpu().numpy()) + sketch = torchvision.transforms.functional.resize( + sketch, size=(sketch.shape[2] // 8, sketch.shape[3] // 8), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + antialias=True) + sketch = sketch + # plt.imsave("Image.png", ((image[0] + 1) / 2).permute(1,2,0).cpu().numpy()) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5a. Compute the number of steps to run sketch conditioning + # sketch_conditioning_steps = (1 - sketch_cond_rate) * num_inference_steps + start_cond_step = int(num_inference_steps * start_cond_rate) + + sketch_start = start_cond_step + sketch_end = sketch_cond_rate * num_inference_steps + start_cond_step + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + text_embeddings.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7a. Prepare pose map latent variables + pose_map = torch.cat([torch.zeros_like(pose_map), torch.zeros_like(pose_map), pose_map, torch.zeros_like(pose_map)]) if do_classifier_free_guidance else pose_map + sketch = torch.cat([torch.zeros_like(sketch), torch.zeros_like(sketch), torch.zeros_like(sketch), sketch]) if do_classifier_free_guidance else sketch + + # 8. Check that sizes of mask, masked image and latents match + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 4) if do_classifier_free_guidance else latents + + # 10a. 
Sketch conditioning + if i < sketch_start or i > sketch_end: + local_sketch = torch.zeros_like(sketch) + else: + local_sketch = sketch + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents, pose_map.to(mask.dtype), local_sketch.to(mask.dtype)], dim=1) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_pose, noise_pred_sketch = noise_pred.chunk(4) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + guidance_scale_pose * (noise_pred_pose - noise_pred_uncond) + guidance_scale_sketch * (noise_pred_sketch - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 11. Post-processing + image = self.decode_latents(latents) + + # 13. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/utils/arg_parser.py b/utils/arg_parser.py new file mode 100644 index 0000000..4f56a45 --- /dev/null +++ b/utils/arg_parser.py @@ -0,0 +1,112 @@ +import os +import argparse + + +def parse_args() -> argparse.Namespace: + """ This function parses the arguments passed to the script. + + Returns: + argparse.Namespace: Namespace containing the arguments. + """ + + parser = argparse.ArgumentParser(description="Multimodal Garment Designer argparse.") + + # Diffusion parameters + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default="runwayml/stable-diffusion-inpainting", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + + # destination folder + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="The output directory where the model predictions will be written.", + ) + + # Accelerator parameters + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
+    )
+
+    # dataset parameters
+    parser.add_argument("--dataset", type=str, required=True, choices=["dresscode", "vitonhd"], help="dataset to use")
+    parser.add_argument(
+        "--dataset_path",
+        type=str,
+        default="",
+        help="Path to the dataset root directory (set it according to the --dataset parameter).",
+    )
+    parser.add_argument("--category", type=str, default="")
+    parser.add_argument("--test_order", type=str, default="unpaired", choices=["unpaired", "paired"])
+
+    # dataloader parameters
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size (per device) for the test dataloader.")
+    parser.add_argument("--num_workers_test", type=int, default=8,
+                        help="Number of workers for the test dataloader.",
+                        )
+
+    # input parameters
+    parser.add_argument("--mask_type", type=str, default="bounding_box", choices=["keypoints", "bounding_box"])
+    parser.add_argument("--no_pose", action="store_true", help="exclude the posemap from the input")
+
+
+    # disentangled classifier-free guidance parameters
+    parser.add_argument("--disentagle", action="store_true", help="use disentangled text/pose/sketch guidance")
+    parser.add_argument("--guidance_scale", type=float, default=7.5, help="text guidance scale, used with --disentagle")
+    parser.add_argument("--guidance_scale_pose", type=float, default=7.5,
+                        help="pose guidance scale, used with --disentagle")
+    parser.add_argument("--guidance_scale_sketch", type=float, default=7.5,
+                        help="sketch guidance scale, used with --disentagle")
+
+    # sketch conditioning parameters
+    parser.add_argument("--sketch_cond_rate", type=float, default=0.2, help="fraction of denoising steps with sketch conditioning")
+    parser.add_argument("--start_cond_rate", type=float, default=0.0, help="fraction of denoising steps to skip before sketch conditioning starts")
+
+    # miscellaneous parameters
+    parser.add_argument("--seed", type=int, default=1234, help="A seed for reproducible inference.")
+    parser.add_argument("--save_name", type=str, default="", help="name of the output subfolder where generated images are saved")
+
+    args = parser.parse_args()
+
+    # the parser does not define --local_rank, so read it from the environment
+    # (set by distributed launchers such as `accelerate launch`) and default to -1
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    args.local_rank = env_local_rank
+
+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
+    return args
diff --git a/utils/image_composition.py b/utils/image_composition.py
new file mode 100644
index 0000000..3ef807f
--- /dev/null
+++ b/utils/image_composition.py
@@ -0,0 +1,23 @@
+import torch
+import torchvision.transforms.functional as F
+
+def compose_img(gt_img, fake_img, im_parse):
+    # keep the ground-truth head (VITON-HD parse labels 1, 2 = hair; 4, 13 = face)
+    seg_head = torch.logical_or(im_parse == 1, im_parse == 2)
+    seg_head = torch.logical_or(seg_head, im_parse == 4)
+    seg_head = torch.logical_or(seg_head, im_parse == 13)
+
+    true_head = gt_img * seg_head
+    true_parts = true_head
+
+    generated_body = (F.pil_to_tensor(fake_img).cuda() / 255) * (~(seg_head))
+
+    return true_parts + generated_body
+
+def compose_img_dresscode(gt_img, fake_img, im_head):
+    # keep the ground-truth head using the provided head mask
+    seg_head = im_head
+    true_head = gt_img * seg_head
+    generated_body = fake_img * ~(seg_head)
+
+    return true_head + generated_body
\ No newline at end of file
diff --git a/utils/image_from_pipe.py b/utils/image_from_pipe.py
new file mode 100644
index 0000000..c1cb0e6
--- /dev/null
+++ b/utils/image_from_pipe.py
@@ -0,0 +1,128 @@
+import os
+from tqdm import tqdm
+import torch
+
+import torchvision.transforms as T
+from diffusers.pipeline_utils import DiffusionPipeline
+from argparse import ArgumentParser
+from torch.utils.data import DataLoader
+from utils.image_composition import compose_img, compose_img_dresscode
+from PIL import Image
+
+@torch.inference_mode()
+def generate_images_from_inpaint_sketch_posemap_pipe(
+    test_order: str,
+    pipe: DiffusionPipeline,
+    test_dataloader: DataLoader,
+    save_name: str,
+    dataset: str,
+    output_dir: str,
+    guidance_scale: float = 7.5,
+    guidance_scale_pose: float = 7.5,
+    guidance_scale_sketch: float = 7.5,
+    sketch_cond_rate: float = 1.0,
+    start_cond_rate: float = 0.0,
+    no_pose: bool = False,
+    disentagle: bool = False,
+    seed: int = 1234,
+    ) -> None:
+    """Generate images from the given test dataloader and save them to the output directory.
+
+    Args:
+        test_order: Test setting, either "paired" or "unpaired"; used to name the output folder.
+        pipe: The diffusion pipeline used for generation.
+        test_dataloader: The test dataloader.
+        save_name: Name of the output subfolder where the generated images are saved.
+        dataset: The dataset name, either "dresscode" or "vitonhd".
+        output_dir: The output directory.
+        guidance_scale: Classifier-free guidance scale for the text prompt.
+        guidance_scale_pose: Guidance scale for the pose map (used when `disentagle` is True).
+        guidance_scale_sketch: Guidance scale for the sketch (used when `disentagle` is True).
+        sketch_cond_rate: Fraction of denoising steps during which the sketch conditioning is applied.
+        start_cond_rate: Fraction of denoising steps to skip before the sketch conditioning starts.
+        no_pose: If True, the pose map conditioning is zeroed out.
+        disentagle: Whether to use disentangled classifier-free guidance with separate text, pose, and sketch scales.
+        seed: Random seed for the CUDA generator.
+
+    Returns:
+        None
+    """
+    assert save_name != "", "save_name must be specified"
+    assert output_dir != "", "output_dir must be specified"
+
+    path = os.path.join(output_dir, f"{save_name}_{test_order}", "images")
+
+    os.makedirs(path, exist_ok=True)
+    generator = torch.Generator("cuda").manual_seed(seed)
+
+    for batch in tqdm(test_dataloader):
+        model_img = batch["image"]
+        mask_img = batch["inpaint_mask"]
+        mask_img = mask_img.type(torch.float32)
+        prompts = batch["original_captions"]  # prompts is a list of length N, where N = batch size
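+        # Keys read from each batch by this loop: "image", "inpaint_mask", "original_captions",
+        # "pose_map", "im_sketch", "im_name", plus "im_parse" (vitonhd) or "stitch_label" (dresscode)
+        # for the final head/body composition.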
+ pose_map = batch["pose_map"] + sketch = batch["im_sketch"] + ext = ".jpg" + + if disentagle: + guidance_scale = guidance_scale + num_samples = 1 + guidance_scale_pose = guidance_scale_pose + guidance_scale_sketch = guidance_scale_sketch + generated_images = pipe( + prompt=prompts, + image=model_img, + mask_image=mask_img, + pose_map=pose_map, + sketch=sketch, + height=512, + width=384, + guidance_scale=guidance_scale, + num_images_per_prompt=num_samples, + generator=generator, + sketch_cond_rate=sketch_cond_rate, + guidance_scale_pose=guidance_scale_pose, + guidance_scale_sketch=guidance_scale_sketch, + start_cond_rate=start_cond_rate, + no_pose=no_pose, + ).images + else: + guidance_scale = 7.5 + num_samples = 1 + generated_images = pipe( + prompt=prompts, + image=model_img, + mask_image=mask_img, + pose_map=pose_map, + sketch=sketch, + height=512, + width=384, + guidance_scale=guidance_scale, + num_images_per_prompt=num_samples, + generator=generator, + sketch_cond_rate=sketch_cond_rate, + start_cond_rate=start_cond_rate, + no_pose=no_pose, + ).images + + for i in range(len(generated_images)): + model_i = model_img[i] * 0.5 + 0.5 + if dataset == "vitonhd": + final_img = compose_img(model_i, generated_images[i], batch['im_parse'][i]) + else: + face = batch["stitch_label"][i].to(model_img.device) + face = T.functional.resize(face, + size=(512,384), + interpolation=T.InterpolationMode.BILINEAR, + antialias = True + ) + + final_img = compose_img_dresscode( + gt_img = model_i, + fake_img = T.functional.to_tensor(generated_images[i]).to(model_img.device), + im_head = face + ) + + final_img = T.functional.to_pil_image(final_img) + final_img.save( + os.path.join(path, batch["im_name"][i].replace(".jpg", ext))) diff --git a/utils/labelmap.py b/utils/labelmap.py new file mode 100644 index 0000000..100c0dd --- /dev/null +++ b/utils/labelmap.py @@ -0,0 +1,36 @@ +label_map={ + "background": 0, + "hat": 1, + "hair": 2, + "sunglasses": 3, + "upper_clothes": 4, + "skirt": 5, + "pants": 6, + "dress": 7, + "belt": 8, + "left_shoe": 9, + "right_shoe": 10, + "head": 11, + "left_leg": 12, + "right_leg": 13, + "left_arm": 14, + "right_arm": 15, + "bag": 16, + "scarf": 17, +} + +label_map_vitonhd = { + 0: ['background', [0, 10]], # 0 is background, 10 is neck + 1: ['hair', [1, 2]], # 1 and 2 are hair + 2: ['face', [4, 13]], + 3: ['upper', [5, 6, 7]], + 4: ['bottom', [9, 12]], + 5: ['left_arm', [14]], + 6: ['right_arm', [15]], + 7: ['left_leg', [16]], + 8: ['right_leg', [17]], + 9: ['left_shoe', [18]], + 10: ['right_shoe', [19]], + 11: ['socks', [8]], + 12: ['noise', [3, 11]] + } \ No newline at end of file diff --git a/utils/posemap.py b/utils/posemap.py new file mode 100644 index 0000000..aa1e78d --- /dev/null +++ b/utils/posemap.py @@ -0,0 +1,58 @@ +import torch +import numpy as np + + + +def kpoint_to_heatmap(kpoint, shape, sigma): + """Converts a 2D keypoint to a gaussian heatmap + + Parameters + ---------- + kpoint: np.array + 2D coordinates of keypoint [x, y]. + shape: tuple + Heatmap dimension (HxW). + sigma: float + Variance value of the gaussian. + + Returns + ------- + heatmap: np.array + A gaussian heatmap HxW. 
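+
+    Example
+    -------
+    A minimal illustrative call (the values below are arbitrary, not taken
+    from the datasets):
+
+    >>> hm = kpoint_to_heatmap(np.array([20.0, 30.0]), (64, 48), sigma=3.0)
+    >>> hm.shape
+    torch.Size([64, 48])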
+ """ + map_h = shape[0] + map_w = shape[1] + if np.any(kpoint > 0): + x, y = kpoint + # x = x * map_w / 384.0 + # y = y * map_h / 512.0 + xy_grid = np.mgrid[:map_w, :map_h].transpose(2, 1, 0) + heatmap = np.exp(-np.sum((xy_grid - (x, y)) ** 2, axis=-1) / sigma ** 2) + heatmap /= (heatmap.max() + np.finfo('float32').eps) + else: + heatmap = np.zeros((map_h, map_w)) + return torch.Tensor(heatmap) + + +def get_coco_body25_mapping(): + #left numbers are coco format while right numbers are body25 format + return { + 0:0, + 1:1, + 2:2, + 3:3, + 4:4, + 5:5, + 6:6, + 7:7, + 8:9, + 9:10, + 10:11, + 11:12, + 12:13, + 13:14, + 14:15, + 15:16, + 16:17, + 17:18 + } \ No newline at end of file diff --git a/utils/set_seeds.py b/utils/set_seeds.py new file mode 100644 index 0000000..80488a3 --- /dev/null +++ b/utils/set_seeds.py @@ -0,0 +1,15 @@ +import random +import os +import numpy as np +import torch +import accelerate + + +def set_seed(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + accelerate.utils.set_seed(seed)
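+
+# Example (hypothetical wiring, not part of the original scripts): call set_seed with the
+# --seed value from utils/arg_parser.py before building the dataloader and the pipeline,
+# so that the Python, NumPy, and PyTorch RNGs are all fixed:
+#
+#     from utils.arg_parser import parse_args
+#     from utils.set_seeds import set_seed
+#
+#     args = parse_args()
+#     set_seed(args.seed)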