Zero-shot personalized image generation models aim to produce images that align with both a given text prompt and a subject image, requiring the model to effectively incorporate both sources of guidance. However, existing methods often struggle to capture fine-grained subject details and frequently prioritize one form of guidance over the other, resulting in suboptimal subject encoding and an imbalance in the generated images. In this study, we uncover key insights into achieving a high-quality balance between subject identity preservation and text-following, notably that 1) the design of the subject image encoder critically influences subject identity preservation, and 2) the text and subject guidance should take effect at different denoising stages. Building on these insights, we introduce a new approach, EZIGen, that employs two main components: a carefully crafted subject image encoder based on the pretrained UNet of the Stable Diffusion model, and a generation process that balances the two guidances by separating their dominance stages and revisiting certain time steps to bootstrap subject transfer quality. Through these two components, EZIGen achieves state-of-the-art results on multiple personalized generation benchmarks with a unified model and 100 times less training data.
Clone this repo:
```bash
git clone git@github.com:ZichengDuan/EZIGen.git
cd EZIGen
```

Prepare the Conda environment:
```bash
conda create -n ezigen python=3.10 -y && conda activate ezigen
```

Install PyTorch:
```bash
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
```

Build Diffusers from source:
```bash
wget https://github.com/huggingface/diffusers/archive/refs/tags/v0.30.1.zip
unzip v0.30.1.zip
cd diffusers-0.30.1
pip install . && cd .. && rm v0.30.1.zip
```

Install the remaining dependencies:
```bash
pip install -r requirements.txt
```
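As a quick, optional sanity check (a minimal sketch, not part of the official setup), you can confirm the pinned versions and CUDA availability before moving on:

```python
# Optional sanity check: confirm the pinned versions and CUDA availability.
import torch
import torchvision
import diffusers

print("torch:", torch.__version__)              # expected 2.0.1 (+cu118)
print("torchvision:", torchvision.__version__)  # expected 0.15.2 (+cu118)
print("diffusers:", diffusers.__version__)      # expected 0.30.1
print("CUDA available:", torch.cuda.is_available())
```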
We provide inference code for both subject-driven generation and subject-driven image editing. Exemplary results can be found in the `outputs` folder.
Download the checkpoint (`checkpoint-200000.zip`) from Google Drive and unzip it to a local folder.
Please first edit `configs/infer_config.yaml` to set the correct checkpoint folder path (e.g. `checkpoint-200000/`).
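The snippet below is a minimal sketch for double-checking that the path written into the config actually exists; the key name `checkpoint_dir` is an assumption, so substitute the actual field used in `configs/infer_config.yaml`:

```python
# Minimal sketch: verify the checkpoint folder referenced by the config exists.
# NOTE: "checkpoint_dir" is an assumed key name; use the real field from the YAML.
# Requires PyYAML.
import os
import yaml

with open("configs/infer_config.yaml") as f:
    cfg = yaml.safe_load(f)

ckpt_dir = cfg.get("checkpoint_dir", "checkpoint-200000/")
print("Checkpoint folder:", ckpt_dir, "->", "found" if os.path.isdir(ckpt_dir) else "missing")
```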
The script for subject-driven generation and human content generation is provided in `infer_generation.sh`:
```bash
# infer_generation.sh
python infer.py \
    --config configs/infer_config.yaml \
    --guidance_scale 7.5 \
    --seed 3154 \
    --split_ratio 0.4 \
    --infer_steps 50 \
    --sim_threshold 0.99 \
    --target_prompt "a dog in police outfit" \
    --subject_prompt "a dog" \
    --subject_img_path "example_images/subjects/dog6.png" \
    --output_root "outputs/" \
    # --num_interations 6
```
Some explanations for the arguments:

- `split_ratio=0.4` means that the last 40% of the timesteps are used for the Appearance Transfer process and the first 60% for the Layout Generation process. The value ranges from 0 to 1, and a larger value indicates more Appearance Transfer (see the sketch after this list).
- `sim_threshold` is the CLIP similarity threshold for the autostop mechanism.
- `subject_prompt` acts as a placeholder; however, it is always recommended to type in the correct class name of the subject image for the best subject feature extraction.
- `--num_interations` is by default set to -1 to give way to the autostop mechanism (with a minimum of 3 and a maximum of 10 iterations); you can uncomment this line in the script and assign the desired iteration number (e.g. 6).
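To make the interplay of `split_ratio`, `sim_threshold`, and the iteration bounds concrete, here is an illustrative Python sketch; it is not the actual logic in `infer.py`, and the two helper functions are placeholders:

```python
# Illustrative sketch only -- not the actual implementation in infer.py.
# The two helpers are placeholders standing in for the real generation and
# CLIP-similarity routines.

def run_appearance_transfer(iteration: int) -> str:
    """Placeholder for one Appearance Transfer pass; returns a dummy result."""
    return f"result_{min(iteration, 4)}"  # pretend results stop changing after iteration 4

def clip_similarity(a: str, b: str) -> float:
    """Placeholder for a CLIP image-similarity score in [0, 1]."""
    return 1.0 if a == b else 0.5

infer_steps, split_ratio, sim_threshold = 50, 0.4, 0.99

# split_ratio=0.4: the first 60% of timesteps drive Layout Generation,
# the last 40% drive Appearance Transfer.
transfer_steps = round(infer_steps * split_ratio)  # 20
layout_steps = infer_steps - transfer_steps        # 30
print(f"Layout Generation: {layout_steps} steps | Appearance Transfer: {transfer_steps} steps")

# Autostop: repeat Appearance Transfer until two consecutive results exceed the
# CLIP similarity threshold, bounded between 3 and 10 iterations.
MIN_ITERS, MAX_ITERS = 3, 10
previous = None
for it in range(1, MAX_ITERS + 1):
    current = run_appearance_transfer(it)
    if previous is not None and it >= MIN_ITERS and clip_similarity(previous, current) >= sim_threshold:
        print(f"Autostop at iteration {it}")
        break
    previous = current
```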
Some subjects are presented in `example_images/subjects`.

The script for subject-driven image editing is provided in `infer_editing.sh`:
```bash
# infer_editing.sh
python infer.py \
    --config configs/infer_config.yaml \
    --guidance_scale 7.5 \
    --seed 3154 \
    --split_ratio 0.4 \
    --infer_steps 50 \
    --sim_threshold 0.99 \
    --target_prompt "a woman" \
    --subject_prompt "a woman" \
    --subject_img_path "example_images/subjects/lifeifei.png" \
    --output_root "outputs/" \
    --foreground_mask_path example_images/source_images_with_masks/woman_mask.png \
    --source_image_path example_images/source_images_with_masks/woman.png \
    --do_editing
    # --num_interations 6
```
Some explanations for the arguments:

- `source_image_path`: the path to the source RGB image for editing.
- `foreground_mask_path`: the path to a 3-channel mask with the foreground as (255, 255, 255) and the background as (0, 0, 0), indicating the area of the source image to edit; it must have the same height and width as the source image (see the sketch after this list).
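If you need to create a mask yourself, the following is a minimal sketch (the rectangle coordinates are arbitrary placeholders) for writing a 3-channel black/white mask with the same size as the source image:

```python
# Minimal sketch: build a 3-channel mask (white = region to edit, black = keep)
# with the same height/width as the source image.
# The rectangle below is an arbitrary placeholder region.
import numpy as np
from PIL import Image

source = Image.open("example_images/source_images_with_masks/woman.png").convert("RGB")
w, h = source.size

mask = np.zeros((h, w, 3), dtype=np.uint8)            # background: (0, 0, 0)
mask[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 255  # foreground: (255, 255, 255)

Image.fromarray(mask).save("my_mask.png")
```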
Some input examples are presented in `example_images/source_images_with_masks`.
You can simply take a generated image from any off-the-shelf image generator and edit it with `infer_editing.sh`; example results from FLUX are shown below:
You can also start your own training by following the instructions below:
Download the YouTubeVIS2019 dataset (training split) following this link: https://competitions.codalab.org/competitions/20128#participate-get_data

Download the COCO2014 dataset (train/val splits) following this link: https://cocodataset.org/#download

Extract the data to local folders and configure the corresponding paths in `configs/train_config.yaml`.
After dataset preparation, you can start DDP training with the HuggingFace Accelerate launcher:

```bash
sh train.sh
```

Alternatively, you can run the training with plain Python on a single GPU:

```bash
python train.py --config configs/train_config.yaml
```
The checkpoint folders (e.g. `checkpoint-5000`) and the TensorBoard logs are automatically saved to `output_dir`, and these checkpoints can in turn be used for inference.

In the default setting with 200k samples, training takes about 4 hours on 8 A100-40G GPUs and 26 hours on a single A100-40G GPU, with a batch size of 1 on each device. The provided checkpoint was trained on a single GPU and therefore carries the postfix '200000'; when training on multiple devices, the checkpoint postfix is `num_samples / num_GPUs` (e.g. 200000 / 8 = 25000, giving `checkpoint-25000`).
- Demo pages
- Inference code and checkpoint
- Training code
Thanks to AnyDoor for providing the YouTubeVIS dataset scripts; shout out to this great work!
If you find this codebase useful for your research, please cite as follows:
```bibtex
@article{duan2024ezigen,
  title={EZIGen: Enhancing zero-shot personalized image generation with precise subject encoding and decoupled guidance},
  author={Duan, Zicheng and Ding, Yuxuan and Gou, Chenhui and Zhou, Ziqin and Smith, Ethan and Liu, Lingqiao},
  journal={arXiv preprint arXiv:2409.08091},
  year={2024}
}
```