diff --git a/.gitignore b/.gitignore index da99824aa3..3781f1c20e 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,4 @@ dmypy.json .DS_Store # More test things -wandb \ No newline at end of file +wandb diff --git a/README.md b/README.md index 9024860947..4b40ea993a 100644 --- a/README.md +++ b/README.md @@ -155,4 +155,4 @@ To use 🤗 PEFT in your publication, please cite it by using the following BibT howpublished = {\url{https://github.com/huggingface/peft}}, year = {2022} } -``` +``` \ No newline at end of file diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index ffd54c18d4..614a94b6d0 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -57,6 +57,8 @@ title: Soft prompts - local: conceptual_guides/ia3 title: IA3 + - local: conceptual_guides/oft + title: OFT/BOFT - sections: - sections: @@ -90,6 +92,8 @@ title: Multitask Prompt Tuning - local: package_reference/oft title: OFT + - local: package_reference/boft + title: BOFT - local: package_reference/poly title: Polytropon - local: package_reference/p_tuning diff --git a/docs/source/conceptual_guides/adapter.md b/docs/source/conceptual_guides/adapter.md index 14e4b8339b..e80ec655f6 100644 --- a/docs/source/conceptual_guides/adapter.md +++ b/docs/source/conceptual_guides/adapter.md @@ -71,6 +71,12 @@ LoHa uses the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_ OFT preserves the hyperspherical energy by learning an orthogonal transformation for neurons to keep the cosine similarity between them unchanged. In practice, this means taking the matrix product of an orthogonal matrix with the pretrained weight matrix. However, to be parameter-efficient, the orthogonal matrix is represented as a block-diagonal matrix with rank `r` blocks. Whereas LoRA reduces the number of trainable parameters with low-rank structures, OFT reduces the number of trainable parameters with a sparse block-diagonal matrix structure. +## Orthogonal Butterfly (BOFT) + +[BOFT](https://hf.co/papers/2311.06243) is a method that primarily focuses on preserving a pretrained model's generative performance in the finetuned model. It tries to maintain the same cosine similarity (hyperspherical energy) between all pairwise neurons in a layer because this better captures the semantic information among neurons. This means OFT is more capable at preserving the subject and it is better for controllable generation (similar to [ControlNet](https://huggingface.co/docs/diffusers/using-diffusers/controlnet)). + +OFT preserves the hyperspherical energy by learning an orthogonal transformation for neurons to keep the cosine similarity between them unchanged. In practice, this means taking the matrix product of an orthogonal matrix with the pretrained weight matrix. However, to be parameter-efficient, the orthogonal matrix is represented as a block-diagonal matrix with rank `r` blocks. Whereas LoRA reduces the number of trainable parameters with low-rank structures, OFT reduces the number of trainable parameters with a sparse block-diagonal matrix structure. + ## Adaptive Low-Rank Adaptation (AdaLoRA) [AdaLoRA](https://hf.co/papers/2303.10512) manages the parameter budget introduced from LoRA by allocating more parameters - in other words, a higher rank `r` - for important weight matrices that are better adapted for a task and pruning less important ones. The rank is controlled by a method similar to singular value decomposition (SVD). The ∆W is parameterized with two orthogonal matrices and a diagonal matrix which contains singular values. This parametrization method avoids iteratively applying SVD which is computationally expensive. Based on this method, the rank of ∆W is adjusted according to an importance score. ∆W is divided into triplets and each triplet is scored according to its contribution to model performance. Triplets with low importance scores are pruned and triplets with high importance scores are kept for finetuning. diff --git a/docs/source/conceptual_guides/oft.md b/docs/source/conceptual_guides/oft.md new file mode 100644 index 0000000000..1693989691 --- /dev/null +++ b/docs/source/conceptual_guides/oft.md @@ -0,0 +1,107 @@ + + +# Orthogonal Finetuning (OFT and BOFT) + +This conceptual guide gives a brief overview of [OFT](https://arxiv.org/abs/2306.07280) and [BOFT](https://arxiv.org/abs/2311.06243), a parameter-efficient fine-tuning technique that utilizes orthogonal matrix to multiplicatively transform the pretrained weight matrices. + +To achieve efficient fine-tuning, OFT represents the weight updates with an orthogonal transformation. The orthogonal transformation is parameterized by an orthogonal matrix multiplied to the pretrained weight matrix. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn’t receive any further adjustments. To produce the final results, both the original and the adapted weights are multiplied togethor. + +Orthogonal Butterfly (BOFT) generalizes OFT with Butterfly factorization and further improves its parameter efficiency and finetuning flexibility. In short, OFT can be viewed as a special case of BOFT. Different from LoRA that uses additive low-rank weight updates, BOFT uses multiplicative orthogonal weight updates. The comparison is shown below. + +
+ +
+ + +BOFT has some advantages compared to LoRA: + +* BOFT proposes a simple yet generic way to finetune pretrained models to downstream tasks, yielding a better preservation of pretraining knowledge and a better parameter efficiency. +* Through the orthogonality, BOFT introduces a structural constraint, i.e., keeping the [hyperspherical energy](https://arxiv.org/abs/1805.09298) unchanged during finetuning. This can effectively reduce the forgetting of pretraining knowledge. +* BOFT uses the butterfly factorization to efficiently parameterize the orthogonal matrix, which yields a compact yet expressive learning space (i.e., hypothesis class). +* The sparse matrix decomposition in BOFT brings in additional inductive biases that are beneficial to generalization. + +In principle, BOFT can be applied to any subset of weight matrices in a neural network to reduce the number of trainable parameters. Given the target layers for injecting BOFT parameters, the number of trainable parameters can be determined based on the size of the weight matrices. + +## Merge OFT/BOFT weights into the base model + +Similar to LoRA, the weights learned by OFT/BOFT can be integrated into the pretrained weight matrices using the merge_and_unload() function. This function merges the adapter weights with the base model which allows you to effectively use the newly merged model as a standalone model. + +
+ +
+ +This works because during training, the orthogonal weight matrix (R in the diagram above) and the pretrained weight matrices are separate. But once training is complete, these weights can actually be merged (multiplied) into a new weight matrix that is equivalent. + +## Utils for OFT / BOFT + +### Common OFT / BOFT parameters in PEFT + +As with other methods supported by PEFT, to fine-tune a model using OFT or BOFT, you need to: + +1. Instantiate a base model. +2. Create a configuration (`OFTConfig` or `BOFTConfig`) where you define OFT/BOFT-specific parameters. +3. Wrap the base model with `get_peft_model()` to get a trainable `PeftModel`. +4. Train the `PeftModel` as you normally would train the base model. + + +### BOFT-specific paramters + +`BOFTConfig` allows you to control how OFT/BOFT is applied to the base model through the following parameters: + +- `boft_block_size`: the BOFT matrix block size across different layers, expressed in `int`. Smaller block size results in sparser update matrices with fewer trainable paramters. **Note**, please choose `boft_block_size` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only +specify either `boft_block_size` or `boft_block_num`, but not both simultaneously or leaving both to 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension. +- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed in `int`. Fewer blocks result in sparser update matrices with fewer trainable paramters. **Note**, please choose `boft_block_num` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only +specify either `boft_block_size` or `boft_block_num`, but not both simultaneously or leaving both to 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension. +- `boft_n_butterfly_factor`: the number of butterfly factors. **Note**, for `boft_n_butterfly_factor=1`, BOFT is the same as vanilla OFT, for `boft_n_butterfly_factor=2`, the effective block size of OFT becomes twice as big and the number of blocks become half. +- `bias`: specify if the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"boft_only"`. +- `boft_dropout`: specify the probability of multiplicative dropout. +- `target_modules`: The modules (for example, attention blocks) to inject the OFT/BOFT matrices. +- `modules_to_save`: List of modules apart from OFT/BOFT matrices to be set as trainable and saved in the final checkpoint. These typically include model's custom head that is randomly initialized for the fine-tuning task. + + + +## BOFT Example Usage + +For an example of the BOFT method application to various downstream tasks, please refer to the following guides: + +Take a look at the following step-by-step guides on how to finetune a model with BOFT: +- [Dreambooth finetuning with BOFT](../task_guides/boft_dreambooth) +- [Controllable generation finetuning with BOFT (ControlNet)](../task_guides/boft_controlnet) + +For the task of image classification, one can initialize the BOFT config for a DinoV2 model as follows: + +```py +import transformers +from transformers import AutoModelForSeq2SeqLM, BOFTConfig +from peft import BOFTConfig, get_peft_model + +config = BOFTConfig( + boft_block_size=4, + boft_n_butterfly_factor=2, + target_modules=["query", "value", "key", "output.dense", "mlp.fc1", "mlp.fc2"], + boft_dropout=0.1, + bias="boft_only", + modules_to_save=["classifier"], +) + +model = transformers.Dinov2ForImageClassification.from_pretrained( + "facebook/dinov2-large", + num_labels=100, +) + +boft_model = get_peft_model(model, config) +``` diff --git a/docs/source/package_reference/boft.md b/docs/source/package_reference/boft.md new file mode 100644 index 0000000000..1384bfe6fa --- /dev/null +++ b/docs/source/package_reference/boft.md @@ -0,0 +1,31 @@ + + +# BOFT + +[Orthogonal Butterfly (BOFT)](https://hf.co/papers/2311.06243) is a generic method designed for finetuning foundation models. It improves the paramter efficiency of the finetuning paradigm -- Orthogonal Finetuning (OFT), by taking inspiration from Cooley-Tukey fast Fourier transform, showing favorable results across finetuning different foundation models, including large vision transformers, large language models and text-to-image diffusion models. + +The abstract from the paper is: + +*Large foundation models are becoming ubiquitous, but training them from scratch is prohibitively expensive. Thus, efficiently adapting these powerful models to downstream tasks is increasingly important. In this paper, we study a principled finetuning paradigm -- Orthogonal Finetuning (OFT) -- for downstream task adaptation. Despite demonstrating good generalizability, OFT still uses a fairly large number of trainable parameters due to the high dimensionality of orthogonal matrices. To address this, we start by examining OFT from an information transmission perspective, and then identify a few key desiderata that enable better parameter-efficiency. Inspired by how the Cooley-Tukey fast Fourier transform algorithm enables efficient information transmission, we propose an efficient orthogonal parameterization using butterfly structures. We apply this parameterization to OFT, creating a novel parameter-efficient finetuning method, called Orthogonal Butterfly (BOFT). By subsuming OFT as a special case, BOFT introduces a generalized orthogonal finetuning framework. Finally, we conduct an extensive empirical study of adapting large vision transformers, large language models, and text-to-image diffusion models to various downstream tasks in vision and language*. + +## BOFTConfig + +[[autodoc]] tuners.boft.config.BOFTConfig + +## BOFTModel + +[[autodoc]] tuners.boft.model.BOFTModel diff --git a/examples/boft_controlnet/__init__.py b/examples/boft_controlnet/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/boft_controlnet/boft_controlnet.md b/examples/boft_controlnet/boft_controlnet.md new file mode 100644 index 0000000000..e6b98c5db9 --- /dev/null +++ b/examples/boft_controlnet/boft_controlnet.md @@ -0,0 +1,177 @@ + + + +# Fine-tuning for controllable generation with BOFT (ControlNet) + +This guide demonstrates how to use BOFT, an orthogonal fine-tuning method, to fine-tune Stable Diffusion with either `stabilityai/stable-diffusion-2-1` or `runwayml/stable-diffusion-v1-5` model for controllable generation. + +By using BOFT from 🤗 PEFT, we can significantly reduce the number of trainable parameters while still achieving impressive results in various fine-tuning tasks across different foundation models. BOFT enhances model efficiency by integrating full-rank orthogonal matrices with a butterfly structure into specific model blocks, such as attention blocks, mirroring the approach used in LoRA. During fine-tuning, only these inserted matrices are trained, leaving the original model parameters untouched. During inference, the trainable BOFT paramteres can be merged into the original model, eliminating any additional computational costs. + +As a member of the **orthogonal finetuning** class, BOFT presents a systematic and principled method for fine-tuning. It possesses several unique properties and has demonstrated superior performance compared to LoRA in a variety of scenarios. For further details on BOFT, please consult the [PEFT's GitHub repo's concept guide OFT](https://https://huggingface.co/docs/peft/index), the [original BOFT paper](https://arxiv.org/abs/2311.06243) and the [original OFT paper](https://arxiv.org/abs/2306.07280). + +In this guide we provide a controllable generation (ControlNet) fine-tuning script that is available in [PEFT's GitHub repo examples](https://github.com/huggingface/peft/tree/main/examples/boft_controlnet). This implementation is adapted from [diffusers's ControlNet](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) and [Hecong Wu's ControlLoRA](https://github.com/HighCWu/ControlLoRA). You can try it out and finetune on your custom images. + +## Set up your environment +Start by cloning the PEFT repository: + +```bash +git clone https://github.com/huggingface/peft +``` + +Navigate to the directory containing the training scripts for fine-tuning Dreambooth with BOFT: +```bash +cd peft/examples/boft_controlnet +``` + +Set up your environment: install PEFT, and all the required libraries. At the time of writing this guide we recommend installing PEFT from source. + +```bash +conda create --name peft python=3.10 +conda activate peft +conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia +conda install xformers -c xformers +pip install -r requirements.txt +pip install git+https://github.com/huggingface/peft +``` + +## Data + +We use the [control-celeba-hq](https://huggingface.co/datasets/oftverse/control-celeba-hq) dataset for landmark-to-face controllable generation. We also provide evaluation scripts to evaluate the controllable generation performance. This task can be used to quantitatively compare different fine-tuning techniques. + +```bash +export DATASET_NAME="oftverse/control-celeba-hq" +``` + +## Train controllable generation (ControlNet) with BOFT + +Start with setting some hyperparamters for BOFT: +```bash +PEFT_TYPE="boft" +BLOCK_NUM=8 +BLOCK_SIZE=0 +N_BUTTERFLY_FACTOR=0 +``` + +Here: + + +Navigate to the directory containing the training scripts for fine-tuning Stable Diffusion with BOFT for controllable generation: + +```bash +./train_controlnet.sh +``` +or +```bash +export MODEL_NAME="stabilityai/stable-diffusion-2-1" +# export MODEL_NAME="runwayml/stable-diffusion-v1-5" + +export DATASET_NAME="oftverse/control-celeba-hq" +export PROJECT_NAME="controlnet_${PEFT_TYPE}" +export RUN_NAME="${PEFT_TYPE}_${BLOCK_NUM}${BLOCK_SIZE}${N_BUTTERFLY_FACTOR}" +export CONTROLNET_PATH="" +export OUTPUT_DIR="./output/${DATASET_NAME}/${RUN_NAME}" + +accelerate launch train_controlnet.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --resume_from_checkpoint=$RESUME_PATH \ + --controlnet_model_name_or_path=$CONTROLNET_PATH \ + --output_dir=$OUTPUT_DIR \ + --report_to="wandb" \ + --dataset_name=$DATASET_NAME \ + --resolution=512 \ + --learning_rate=1e-5 \ + --checkpointing_steps=5000 \ + --max_train_steps=50000 \ + --validation_steps=2000 \ + --num_validation_images=12 \ + --train_batch_size=4 \ + --dataloader_num_workers=2 \ + --seed="0" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --wandb_project_name=$PROJECT_NAME \ + --wandb_run_name=$RUN_NAME \ + --enable_xformers_memory_efficient_attention \ + --use_boft \ + --boft_block_num=$BLOCK_NUM \ + --boft_block_size=$BLOCK_SIZE \ + --boft_n_butterfly_factor=$N_BUTTERFLY_FACTOR \ + --boft_dropout=0.1 \ + --boft_bias="boft_only" \ + --report_to="wandb" \ +``` + +Run inference on the saved model to sample new images from the validation set: + +```bash +./test_controlnet.sh +``` +or +```bash +ITER_NUM=50000 + +export MODEL_NAME="stabilityai/stable-diffusion-2-1" +# export MODEL_NAME="runwayml/stable-diffusion-v1-5" + +export RUN_NAME="${PEFT_TYPE}_${BLOCK_NUM}${BLOCK_SIZE}${N_BUTTERFLY_FACTOR}" +export DATASET_NAME="oftverse/control-celeba-hq" +export CKPT_NAME="checkpoint-${ITER_NUM}" +export OUTPUT_DIR="./output/${DATASET_NAME}/${RUN_NAME}/${CKPT_NAME}" +export CONTROLNET_PATH="${OUTPUT_DIR}/controlnet/model.safetensors" +export UNET_PATH="${OUTPUT_DIR}/unet/${RUN_NAME}" +export RESULTS_PATH="${OUTPUT_DIR}/results" + +accelerate launch test_controlnet.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --controlnet_path=$CONTROLNET_PATH \ + --unet_path=$UNET_PATH \ + --adapter_name=$RUN_NAME \ + --output_dir=$RESULTS_PATH \ + --dataset_name=$DATASET_NAME \ + +``` + +Run evaluation on the sampled images to evaluate the landmark reprojection error: + +```bash +./eval.sh +``` +or +```bash +ITER_NUM=50000 + +export MODEL_NAME="stabilityai/stable-diffusion-2-1" +# export MODEL_NAME="runwayml/stable-diffusion-v1-5" + +export RUN_NAME="${PEFT_TYPE}_${BLOCK_NUM}${BLOCK_SIZE}${N_BUTTERFLY_FACTOR}" +export DATASET_NAME="oftverse/control-celeba-hq" +export CKPT_NAME="checkpoint-${ITER_NUM}" +export OUTPUT_DIR="./output/${DATASET_NAME}/${RUN_NAME}/${CKPT_NAME}" +export CONTROLNET_PATH="${OUTPUT_DIR}/controlnet/model.safetensors" +export UNET_PATH="${OUTPUT_DIR}/unet/${RUN_NAME}" + +accelerate launch eval.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --controlnet_path=$CONTROLNET_PATH \ + --unet_path=$UNET_PATH \ + --adapter_name=$RUN_NAME \ + --output_dir=$OUTPUT_DIR \ + --dataset_name=$DATASET_NAME \ + --vis_overlays \ +``` \ No newline at end of file diff --git a/examples/boft_controlnet/eval.py b/examples/boft_controlnet/eval.py new file mode 100644 index 0000000000..48ff615b4c --- /dev/null +++ b/examples/boft_controlnet/eval.py @@ -0,0 +1,200 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +import glob +import os +from pathlib import Path + +import cv2 +import face_alignment +import numpy as np +import torch +from accelerate import Accelerator +from skimage.io import imread +from torchvision.utils import save_image +from tqdm import tqdm +from transformers import AutoTokenizer +from utils.args_loader import parse_args +from utils.dataset import make_dataset + + +detect_model = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device="cuda:0", flip_input=False) + +# with open('./data/celebhq-text/prompt_val_blip_full.json', 'rt') as f: # fill50k, COCO +# for line in f: +# val_data = json.loads(line) + +end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype=np.int32) - 1 + + +def count_txt_files(directory): + pattern = os.path.join(directory, "*.txt") + txt_files = glob.glob(pattern) + return len(txt_files) + + +def plot_kpts(image, kpts, color="g"): + """Draw 68 key points + Args: + image: the input image + kpt: (68, 3). + """ + if color == "r": + c = (255, 0, 0) + elif color == "g": + c = (0, 255, 0) + elif color == "b": + c = (255, 0, 0) + image = image.copy() + kpts = kpts.copy() + radius = max(int(min(image.shape[0], image.shape[1]) / 200), 1) + for i in range(kpts.shape[0]): + st = kpts[i, :2] + if kpts.shape[1] == 4: + if kpts[i, 3] > 0.5: + c = (0, 255, 0) + else: + c = (0, 0, 255) + image = cv2.circle(image, (int(st[0]), int(st[1])), radius, c, radius * 2) + if i in end_list: + continue + ed = kpts[i + 1, :2] + image = cv2.line(image, (int(st[0]), int(st[1])), (int(ed[0]), int(ed[1])), (255, 255, 255), radius) + return image + + +def generate_landmark2d(dataset, input_dir, pred_lmk_dir, gt_lmk_dir, vis=False): + print("Generate 2d landmarks ...") + os.makedirs(pred_lmk_dir, exist_ok=True) + + imagepath_list = sorted(glob.glob(f"{input_dir}/pred*.png")) + + for imagepath in tqdm(imagepath_list): + name = Path(imagepath).stem + idx = int(name.split("_")[-1]) + pred_txt_path = os.path.join(pred_lmk_dir, f"{idx}.txt") + gt_lmk_path = os.path.join(gt_lmk_dir, f"{idx}_gt_lmk.jpg") + gt_txt_path = os.path.join(gt_lmk_dir, f"{idx}.txt") + gt_img_path = os.path.join(gt_lmk_dir, f"{idx}_gt_img.jpg") + + if (not os.path.exists(pred_txt_path)) or (not os.path.exists(gt_txt_path)): + image = imread(imagepath) # [:, :, :3] + out = detect_model.get_landmarks(image) + if out is None: + continue + + pred_kpt = out[0].squeeze() + np.savetxt(pred_txt_path, pred_kpt) + + # Your existing code for obtaining the image tensor + gt_lmk_img = dataset[idx]["conditioning_pixel_values"] + save_image(gt_lmk_img, gt_lmk_path) + + gt_img = (dataset[idx]["pixel_values"]) * 0.5 + 0.5 + save_image(gt_img, gt_img_path) + + gt_img = (gt_img.permute(1, 2, 0) * 255).type(torch.uint8).cpu().numpy() + out = detect_model.get_landmarks(gt_img) + if out is None: + continue + + gt_kpt = out[0].squeeze() + np.savetxt(gt_txt_path, gt_kpt) + # gt_image = cv2.resize(cv2.imread(gt_lmk_path), (512, 512)) + + if vis: + gt_lmk_image = cv2.imread(gt_lmk_path) + + # visualize predicted landmarks + vis_path = os.path.join(pred_lmk_dir, f"{idx}_overlay.jpg") + image = cv2.imread(imagepath) + image_point = plot_kpts(image, pred_kpt) + cv2.imwrite(vis_path, np.concatenate([image_point, gt_lmk_image], axis=1)) + + # visualize gt landmarks + vis_path = os.path.join(gt_lmk_dir, f"{idx}_overlay.jpg") + image = cv2.imread(gt_img_path) + image_point = plot_kpts(image, gt_kpt) + cv2.imwrite(vis_path, np.concatenate([image_point, gt_lmk_image], axis=1)) + + +def landmark_comparison(val_dataset, lmk_dir, gt_lmk_dir): + print("Calculating reprojection error") + lmk_err = [] + + pbar = tqdm(range(len(val_dataset))) + for i in pbar: + # line = val_dataset[i] + # img_name = line["image"].split(".")[0] + lmk1_path = os.path.join(gt_lmk_dir, f"{i}.txt") + lmk1 = np.loadtxt(lmk1_path) + lmk2_path = os.path.join(lmk_dir, f"{i}.txt") + + if not os.path.exists(lmk2_path): + print(f"{lmk2_path} not exist") + continue + + lmk2 = np.loadtxt(lmk2_path) + lmk_err.append(np.mean(np.linalg.norm(lmk1 - lmk2, axis=1))) + pbar.set_description(f"lmk_err: {np.mean(lmk_err):.5f}") + + print("Reprojection error:", np.mean(lmk_err)) + np.save(os.path.join(lmk_dir, "lmk_err.npy"), lmk_err) + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=logging_dir, + ) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + val_dataset = make_dataset(args, tokenizer, accelerator, "test") + + gt_lmk_dir = os.path.join(args.output_dir, "gt_lmk") + if not os.path.exists(gt_lmk_dir): + os.makedirs(gt_lmk_dir, exist_ok=True) + + pred_lmk_dir = os.path.join(args.output_dir, "pred_lmk") + if not os.path.exists(pred_lmk_dir): + os.makedirs(pred_lmk_dir, exist_ok=True) + + input_dir = os.path.join(args.output_dir, "results") + + generate_landmark2d(val_dataset, input_dir, pred_lmk_dir, gt_lmk_dir, args.vis_overlays) + + if count_txt_files(pred_lmk_dir) == len(val_dataset) and count_txt_files(gt_lmk_dir) == len(val_dataset): + landmark_comparison(val_dataset, pred_lmk_dir, gt_lmk_dir) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/boft_controlnet/eval.sh b/examples/boft_controlnet/eval.sh new file mode 100755 index 0000000000..d5ed282ea1 --- /dev/null +++ b/examples/boft_controlnet/eval.sh @@ -0,0 +1,29 @@ +PEFT_TYPE="boft" +BLOCK_NUM=8 +BLOCK_SIZE=0 +N_BUTTERFLY_FACTOR=1 +ITER_NUM=50000 + +export RUN_NAME="${PEFT_TYPE}_${BLOCK_NUM}${BLOCK_SIZE}${N_BUTTERFLY_FACTOR}" + +export MODEL_NAME="stabilityai/stable-diffusion-2-1" +# export MODEL_NAME="runwayml/stable-diffusion-v1-5" + +export DATASET_NAME="oftverse/control-celeba-hq" +export CKPT_NAME="checkpoint-${ITER_NUM}" +export OUTPUT_DIR="./output/${DATASET_NAME}/${RUN_NAME}/${CKPT_NAME}" +export CONTROLNET_PATH="${OUTPUT_DIR}/controlnet/model.safetensors" +export UNET_PATH="${OUTPUT_DIR}/unet/${RUN_NAME}" + + +accelerate launch eval.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --controlnet_path=$CONTROLNET_PATH \ + --unet_path=$UNET_PATH \ + --adapter_name=$RUN_NAME \ + --output_dir=$OUTPUT_DIR \ + --dataset_name=$DATASET_NAME \ + --vis_overlays \ + + diff --git a/examples/boft_controlnet/requirements.txt b/examples/boft_controlnet/requirements.txt new file mode 100644 index 0000000000..2eda894185 --- /dev/null +++ b/examples/boft_controlnet/requirements.txt @@ -0,0 +1,8 @@ +datasets==2.16.1 +diffusers==0.17.1 +transformers==4.36.2 +accelerate==0.25.0 +wandb==0.16.1 +scikit-image==0.22.0 +opencv-python==4.9.0.80 +face-alignment==1.4.1 \ No newline at end of file diff --git a/examples/boft_controlnet/test_controlnet.py b/examples/boft_controlnet/test_controlnet.py new file mode 100644 index 0000000000..5d8767a29a --- /dev/null +++ b/examples/boft_controlnet/test_controlnet.py @@ -0,0 +1,129 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +import os +import sys +import time +from pathlib import Path + +import numpy as np +import torch +import torch.utils.checkpoint +from accelerate import Accelerator +from diffusers import DDIMScheduler +from diffusers.utils import check_min_version +from safetensors.torch import load_file +from tqdm import tqdm +from transformers import AutoTokenizer +from utils.args_loader import parse_args +from utils.dataset import make_dataset +from utils.light_controlnet import ControlNetModel +from utils.pipeline_controlnet import LightControlNetPipeline +from utils.unet_2d_condition import UNet2DConditionNewModel + + +sys.path.append("../../src") +from peft import PeftModel + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") +device = torch.device("cuda:0") + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=logging_dir, + ) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + val_dataset = make_dataset(args, tokenizer, accelerator, "test") + + controlnet_path = args.controlnet_path + unet_path = args.unet_path + + controlnet = ControlNetModel() + controlnet.load_state_dict(load_file(controlnet_path)) + unet = UNet2DConditionNewModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + unet = PeftModel.from_pretrained(unet, unet_path, adapter_name=args.adapter_name) + + pipe = LightControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + controlnet=controlnet, + unet=unet.model, + torch_dtype=torch.float32, + requires_safety_checker=False, + ).to(device) + + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir, exist_ok=True) + + exist_lst = [int(img.split("_")[-1][:-4]) for img in os.listdir(args.output_dir)] + all_lst = np.arange(len(val_dataset)) + idx_lst = [item for item in all_lst if item not in exist_lst] + + print("Number of images to be processed: ", len(idx_lst)) + + np.random.seed(seed=int(time.time())) + np.random.shuffle(idx_lst) + + for idx in tqdm(idx_lst): + output_path = os.path.join(args.output_dir, f"pred_img_{idx:04d}.png") + + if not os.path.exists(output_path): + data = val_dataset[idx.item()] + negative_prompt = "low quality, blurry, unfinished" + + with torch.no_grad(): + pred_img = pipe( + data["text"], + [data["conditioning_pixel_values"]], + num_inference_steps=50, + guidance_scale=7, + negative_prompt=negative_prompt, + ).images[0] + + pred_img.save(output_path) + + # control_img = Image.fromarray( + # (data["conditioning_pixel_value"] * 255).numpy().transpose(1, 2, 0).astype(np.uint8) + # ) + # gt_img = Image.fromarray( + # ((data["pixel_value"] + 1.0) * 0.5 * 255).numpy().transpose(1, 2, 0).astype(np.uint8) + # ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/boft_controlnet/test_controlnet.sh b/examples/boft_controlnet/test_controlnet.sh new file mode 100755 index 0000000000..3afd8d9c7a --- /dev/null +++ b/examples/boft_controlnet/test_controlnet.sh @@ -0,0 +1,29 @@ +PEFT_TYPE="boft" +BLOCK_NUM=8 +BLOCK_SIZE=0 +N_BUTTERFLY_FACTOR=1 +ITER_NUM=50000 + +export RUN_NAME="${PEFT_TYPE}_${BLOCK_NUM}${BLOCK_SIZE}${N_BUTTERFLY_FACTOR}" + +export MODEL_NAME="stabilityai/stable-diffusion-2-1" +# export MODEL_NAME="runwayml/stable-diffusion-v1-5" + +export DATASET_NAME="oftverse/control-celeba-hq" +export CKPT_NAME="checkpoint-${ITER_NUM}" +export OUTPUT_DIR="./output/${DATASET_NAME}/${RUN_NAME}/${CKPT_NAME}" +export CONTROLNET_PATH="${OUTPUT_DIR}/controlnet/model.safetensors" +export UNET_PATH="${OUTPUT_DIR}/unet/${RUN_NAME}" +export RESULTS_PATH="${OUTPUT_DIR}/results" + + +accelerate launch test_controlnet.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --controlnet_path=$CONTROLNET_PATH \ + --unet_path=$UNET_PATH \ + --adapter_name=$RUN_NAME \ + --output_dir=$RESULTS_PATH \ + --dataset_name=$DATASET_NAME \ + + diff --git a/examples/boft_controlnet/train_controlnet.py b/examples/boft_controlnet/train_controlnet.py new file mode 100644 index 0000000000..27f3e81892 --- /dev/null +++ b/examples/boft_controlnet/train_controlnet.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +import itertools +import logging +import math +import os +from pathlib import Path + +import datasets +import diffusers +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDIMScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from packaging import version +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from utils.args_loader import ( + import_model_class_from_model_name_or_path, + parse_args, +) +from utils.dataset import collate_fn, log_validation, make_dataset +from utils.light_controlnet import ControlNetModel +from utils.tracemalloc import TorchTracemalloc, b2mb +from utils.unet_2d_condition import UNet2DConditionNewModel + +from peft import BOFTConfig, get_peft_model +from peft.peft_model import PeftModel + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.16.0.dev0") + +logger = get_logger(__name__) + +UNET_TARGET_MODULES = ["to_q", "to_v", "to_k", "query", "value", "key"] + +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] + + +@torch.no_grad() +def save_adaptor(accelerator, output_dir, nets_dict): + for net_key in nets_dict.keys(): + net_model = nets_dict[net_key] + unwarpped_net = accelerator.unwrap_model(net_model) + + if isinstance(unwarpped_net, PeftModel): + unwarpped_net.save_pretrained( + os.path.join(output_dir, net_key), + state_dict=accelerator.get_state_dict(net_model), + safe_serialization=True, + ) + else: + accelerator.save_model( + unwarpped_net, + os.path.join(output_dir, net_key), + safe_serialization=True, + ) + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=logging_dir, + ) + + if args.report_to == "wandb": + wandb_init = { + "wandb": { + "name": args.wandb_run_name, + "mode": "online", + } + } + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + logger.info(accelerator.state, main_process_only=False) + + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionNewModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + ) + + controlnet = ControlNetModel() + + if args.controlnet_model_name_or_path != "": + logger.info(f"Loading existing controlnet weights from {args.controlnet_model_name_or_path}") + controlnet.load_state_dict(torch.load(args.controlnet_model_name_or_path)) + + if args.use_boft: + config = BOFTConfig( + boft_block_size=args.boft_block_size, + boft_block_num=args.boft_block_num, + boft_n_butterfly_factor=args.boft_n_butterfly_factor, + target_modules=UNET_TARGET_MODULES, + boft_dropout=args.boft_dropout, + bias=args.boft_bias, + ) + unet = get_peft_model(unet, config) + unet.print_trainable_parameters() + + vae.requires_grad_(False) + controlnet.requires_grad_(True) + + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + + unet.train() + controlnet.train() + + if args.train_text_encoder and args.use_boft: + config = BOFTConfig( + boft_block_size=args.boft_block_size, + boft_block_num=args.boft_block_num, + boft_n_butterfly_factor=args.boft_n_butterfly_factor, + target_modules=TEXT_ENCODER_TARGET_MODULES, + boft_dropout=args.boft_dropout, + bias=args.boft_bias, + ) + text_encoder = get_peft_model(text_encoder, config, adapter_name=args.wandb_run_name) + text_encoder.print_trainable_parameters() + + if args.train_text_encoder: + text_encoder.train() + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + controlnet.to(accelerator.device, dtype=weight_dtype) + + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + if args.train_text_encoder and not (args.use_lora or args.use_boft or args.use_oft): + text_encoder.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + unet.enable_gradient_checkpointing() + if args.train_text_encoder and not (args.use_lora or args.use_boft or args.use_oft): + text_encoder.gradient_checkpointing_enable() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = [param for param in controlnet.parameters() if param.requires_grad] + params_to_optimize += [param for param in unet.parameters() if param.requires_grad] + + if args.train_text_encoder: + params_to_optimize += [param for param in text_encoder.parameters() if param.requires_grad] + + # Optimizer creation + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Load the dataset + train_dataset = make_dataset(args, tokenizer, accelerator, "train") + val_dataset = make_dataset(args, tokenizer, accelerator, "test") + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + if args.train_text_encoder: + text_encoder = accelerator.prepare(text_encoder) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers(args.wandb_project_name, config=vars(args), init_kwargs=wandb_init) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + if "checkpoint-current" in dirs: + path = "checkpoint-current" + dirs = [d for d in dirs if d.startswith("checkpoint") and d.endswith("0")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + + else: + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + if path.split("-")[1] == "current": + global_step = int(dirs[-1].split("-")[1]) + else: + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + disable=not accelerator.is_local_main_process, + ) + + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + with TorchTracemalloc() as tracemalloc: + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + continue + + with accelerator.accumulate(controlnet), accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + + # Get the guided hint for the UNet (320 dim) + guided_hint = controlnet( + controlnet_cond=controlnet_image, + ) + + # Predict the noise residual + model_pred = unet( + noisy_latents, + timesteps, + guided_hint=guided_hint, + encoder_hidden_states=encoder_hidden_states, + ).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(controlnet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else itertools.chain( + controlnet.parameters(), + ) + ) + + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + global_step += 1 + + step_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + + if accelerator.is_main_process: + if global_step % args.validation_steps == 0 or global_step == 1: + logger.info(f"Running validation... \n Generating {args.num_validation_images} images.") + logger.info("Running validation... ") + + with torch.no_grad(): + log_validation(val_dataset, text_encoder, unet, controlnet, args, accelerator) + + if global_step % args.checkpointing_steps == 0: + save_adaptor(accelerator, step_save_path, {"controlnet": controlnet, "unet": unet}) + + # save text_encoder if any + if args.train_text_encoder: + save_adaptor(accelerator, step_save_path, {"text_encoder": text_encoder}) + + accelerator.save_state(step_save_path) + + logger.info(f"Saved {global_step} state to {step_save_path}") + logger.info(f"Saved current state to {step_save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage + accelerator.print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}") + accelerator.print(f"GPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}") + accelerator.print(f"GPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}") + accelerator.print( + f"GPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}" + ) + + accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}") + accelerator.print(f"CPU Memory consumed at the end of the train (end-begin): {tracemalloc.cpu_used}") + accelerator.print(f"CPU Peak Memory consumed during the train (max-begin): {tracemalloc.cpu_peaked}") + accelerator.print( + f"CPU Total Peak Memory consumed during the train (max): {tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)}" + ) + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/boft_controlnet/train_controlnet.sh b/examples/boft_controlnet/train_controlnet.sh new file mode 100755 index 0000000000..efad2c4348 --- /dev/null +++ b/examples/boft_controlnet/train_controlnet.sh @@ -0,0 +1,42 @@ +PEFT_TYPE="boft" +BLOCK_NUM=8 +BLOCK_SIZE=0 +N_BUTTERFLY_FACTOR=1 + +export DATASET_NAME="oftverse/control-celeba-hq" +export PROJECT_NAME="controlnet_${PEFT_TYPE}" +export RUN_NAME="${PEFT_TYPE}_${BLOCK_NUM}${BLOCK_SIZE}${N_BUTTERFLY_FACTOR}" +export CONTROLNET_PATH="" + +export MODEL_NAME="stabilityai/stable-diffusion-2-1" +# export MODEL_NAME="runwayml/stable-diffusion-v1-5" + +export OUTPUT_DIR="./output/${DATASET_NAME}/${RUN_NAME}" + +accelerate launch train_controlnet.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --resume_from_checkpoint=$RESUME_PATH \ + --controlnet_model_name_or_path=$CONTROLNET_PATH \ + --output_dir=$OUTPUT_DIR \ + --report_to="wandb" \ + --dataset_name=$DATASET_NAME \ + --resolution=512 \ + --learning_rate=1e-5 \ + --checkpointing_steps=500 \ + --max_train_steps=50000 \ + --validation_steps=5000 \ + --num_validation_images=12 \ + --train_batch_size=4 \ + --dataloader_num_workers=2 \ + --seed="0" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --wandb_project_name=$PROJECT_NAME \ + --wandb_run_name=$RUN_NAME \ + --enable_xformers_memory_efficient_attention \ + --use_boft \ + --boft_block_num=$BLOCK_NUM \ + --boft_block_size=$BLOCK_SIZE \ + --boft_n_butterfly_factor=$N_BUTTERFLY_FACTOR \ + --boft_dropout=0.1 \ + --boft_bias="boft_only" \ \ No newline at end of file diff --git a/examples/boft_controlnet/utils/__init__.py b/examples/boft_controlnet/utils/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/examples/boft_controlnet/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/examples/boft_controlnet/utils/args_loader.py b/examples/boft_controlnet/utils/args_loader.py new file mode 100644 index 0000000000..16e3c9a8ee --- /dev/null +++ b/examples/boft_controlnet/utils/args_loader.py @@ -0,0 +1,447 @@ +import argparse +import os +from typing import Optional + +from huggingface_hub import HfFolder, whoami +from transformers import PretrainedConfig + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( + RobertaSeriesModelWithTransformation, + ) + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." + ), + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="wandb", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--wandb_key", + type=str, + default=None, + help=("If report to option is set to wandb, api-key for wandb used for login to wandb "), + ) + parser.add_argument( + "--wandb_project_name", + type=str, + default=None, + help=("If report to option is set to wandb, project name in wandb for log tracking "), + ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help=("If report to option is set to wandb, project name in wandb for log tracking "), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + # evaluation arguments + parser.add_argument("--controlnet_path", type=str, default=None, help="Path to pretrained controlnet.") + parser.add_argument("--unet_path", type=str, default=None, help="Path to pretrained unet.") + parser.add_argument("--adapter_name", type=str, default=None, help="Name of the adapter to use.") + parser.add_argument("--vis_overlays", action="store_true", help="Whether to visualize the landmarks.") + + # self-invented arguments + + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + parser.add_argument( + "--name", + type=str, + help=("The name of the current experiment run, consists of [data]-[prompt]"), + ) + + # BOFT args + parser.add_argument("--use_boft", action="store_true", help="Whether to use BOFT for parameter efficient tuning") + parser.add_argument("--boft_block_num", type=int, default=8, help="The number of BOFT blocks") + parser.add_argument("--boft_block_size", type=int, default=0, help="The size of BOFT blocks") + parser.add_argument("--boft_n_butterfly_factor", type=int, default=0, help="The number of butterfly factors") + parser.add_argument("--boft_dropout", type=float, default=0.1, help="BOFT dropout, only used if use_boft is True") + parser.add_argument( + "--boft_bias", + type=str, + default="none", + help="Bias type for BOFT. Can be 'none', 'all' or 'boft_only', only used if use_boft is True", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args diff --git a/examples/boft_controlnet/utils/dataset.py b/examples/boft_controlnet/utils/dataset.py new file mode 100644 index 0000000000..1de3c8cc36 --- /dev/null +++ b/examples/boft_controlnet/utils/dataset.py @@ -0,0 +1,207 @@ +import random + +import numpy as np +import torch +import wandb +from datasets import load_dataset +from diffusers import DDIMScheduler +from PIL import Image +from torchvision import transforms +from utils.pipeline_controlnet import LightControlNetPipeline + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +def log_validation(val_dataset, text_encoder, unet, controlnet, args, accelerator): + pipeline = LightControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + controlnet=accelerator.unwrap_model(controlnet, keep_fp32_wrapper=True), + unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).model, + text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True), + safety_checker=None, + revision=args.revision, + ) + + pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + + pipeline.set_progress_bar_config(disable=True) + + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + image_logs = [] + + for idx in range(args.num_validation_images): + data = val_dataset[idx] + validation_prompt = data["text"] + validation_image = data["conditioning_pixel_values"] + + image = pipeline( + validation_prompt, + [validation_image], + num_inference_steps=50, + generator=generator, + )[0][0] + + image_logs.append( + { + "validation_image": validation_image, + "image": image, + "validation_prompt": validation_prompt, + } + ) + + for tracker in accelerator.trackers: + formatted_images = [] + + for log in image_logs: + image = log["image"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({"validation": formatted_images}) + + del pipeline + torch.cuda.empty_cache() + + +def make_dataset(args, tokenizer, accelerator, split="train"): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + if args.train_data_dir is not None: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset[split].column_names + + # Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if random.random() < args.proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + examples["input_ids"] = tokenize_captions(examples) + + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset[split] = dataset[split].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + split_dataset = dataset[split].with_transform(preprocess_train) + + return split_dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.stack([example["input_ids"] for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "input_ids": input_ids, + } diff --git a/examples/boft_controlnet/utils/light_controlnet.py b/examples/boft_controlnet/utils/light_controlnet.py new file mode 100644 index 0000000000..fce1774041 --- /dev/null +++ b/examples/boft_controlnet/utils/light_controlnet.py @@ -0,0 +1,263 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, +) +from diffusers.utils import BaseOutput, logging +from torch import nn +from torch.nn import functional as F + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + out_channels: int = 320, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + ): + super().__init__() + + # for control image + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=out_channels, + block_out_channels=conditioning_embedding_out_channels, + ) + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Parameters: + `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `Attention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + controlnet_cond: torch.FloatTensor, + ) -> Union[ControlNetOutput, Tuple]: + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # 2. pre-process + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + return controlnet_cond + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/examples/boft_controlnet/utils/pipeline_controlnet.py b/examples/boft_controlnet/utils/pipeline_controlnet.py new file mode 100644 index 0000000000..7d301d7c25 --- /dev/null +++ b/examples/boft_controlnet/utils/pipeline_controlnet.py @@ -0,0 +1,452 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline +from diffusers.utils import BaseOutput, is_compiled_module, logging +from torch.nn import functional as F +from utils.light_controlnet import ControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class LightControlNetPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +class LightControlNetPipeline(StableDiffusionControlNetPipeline): + _optional_components = ["safety_checker", "feature_extractor"] + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + else: + control_model_input = latent_model_input + + # Get the guided hint for the UNet (320 dim) + guided_hint = self.controlnet( + controlnet_cond=image, + ) + + # Predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + guided_hint=guided_hint, + encoder_hidden_states=prompt_embeds, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return LightControlNetPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/boft_controlnet/utils/tracemalloc.py b/examples/boft_controlnet/utils/tracemalloc.py new file mode 100644 index 0000000000..ce31dbb208 --- /dev/null +++ b/examples/boft_controlnet/utils/tracemalloc.py @@ -0,0 +1,58 @@ +import gc +import threading + +import psutil +import torch + + +# Converting Bytes to Megabytes +def b2mb(x): + return int(x / 2**20) + + +# This context manager is used to track the peak memory usage of the process +class TorchTracemalloc: + def __enter__(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.cuda.memory_allocated() + self.process = psutil.Process() + + self.cpu_begin = self.cpu_mem_used() + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + return self + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_peak = -1 + + while True: + self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def __exit__(self, *exc): + self.peak_monitoring = False + + gc.collect() + torch.cuda.empty_cache() + self.end = torch.cuda.memory_allocated() + self.peak = torch.cuda.max_memory_allocated() + self.used = b2mb(self.end - self.begin) + self.peaked = b2mb(self.peak - self.begin) + + self.cpu_end = self.cpu_mem_used() + self.cpu_used = b2mb(self.cpu_end - self.cpu_begin) + self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin) + # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") diff --git a/examples/boft_controlnet/utils/unet_2d_condition.py b/examples/boft_controlnet/utils/unet_2d_condition.py new file mode 100644 index 0000000000..9a384cb9ba --- /dev/null +++ b/examples/boft_controlnet/utils/unet_2d_condition.py @@ -0,0 +1,277 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from diffusers.models import UNet2DConditionModel +from diffusers.utils import BaseOutput, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet2DConditionNewModel(UNet2DConditionModel): + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + guided_hint: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + encoder_attention_mask (`torch.Tensor`): + (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = + discard. Mask will be converted into a bias, which adds large negative values to attention scores + corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + added_cond_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time + embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and + `addition_embed_type` for more information. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + emb = emb + aug_emb + elif self.config.addition_embed_type == "text_image": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + + aug_emb = self.add_embedding(text_embs, image_embs) + emb = emb + aug_emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + + # 2. pre-process and insert conditioning (ControlNet) + # Note: the added "guided_hint" is the only difference between this implementation and the original UNet2DConditionModel + sample = self.conv_in(sample) + sample = guided_hint + sample if guided_hint is not None else sample + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/examples/boft_dreambooth/.gitignore b/examples/boft_dreambooth/.gitignore new file mode 100644 index 0000000000..adbb97d2d3 --- /dev/null +++ b/examples/boft_dreambooth/.gitignore @@ -0,0 +1 @@ +data/ \ No newline at end of file diff --git a/examples/boft_dreambooth/__init__.py b/examples/boft_dreambooth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/boft_dreambooth/boft_dreambooth.md b/examples/boft_dreambooth/boft_dreambooth.md new file mode 100644 index 0000000000..a4dda5b412 --- /dev/null +++ b/examples/boft_dreambooth/boft_dreambooth.md @@ -0,0 +1,165 @@ + + +# DreamBooth fine-tuning with BOFT + +This guide demonstrates how to use BOFT, an orthogonal fine-tuning method, to fine-tune Dreambooth with either `stabilityai/stable-diffusion-2-1` or `runwayml/stable-diffusion-v1-5` model. + +By using BOFT from 🤗 PEFT, we can significantly reduce the number of trainable parameters while still achieving impressive results in various fine-tuning tasks across different foundation models. BOFT enhances model efficiency by integrating full-rank orthogonal matrices with a butterfly structure into specific model blocks, such as attention blocks, mirroring the approach used in LoRA. During fine-tuning, only these inserted matrices are trained, leaving the original model parameters untouched. During inference, the trainable BOFT paramteres can be merged into the original model, eliminating any additional computational costs. + +As a member of the **orthogonal finetuning** class, BOFT presents a systematic and principled method for fine-tuning. It possesses several unique properties and has demonstrated superior performance compared to LoRA in a variety of scenarios. For further details on BOFT, please consult the [PEFT's GitHub repo's concept guide OFT](https://https://huggingface.co/docs/peft/index), the [original BOFT paper](https://arxiv.org/abs/2311.06243) and the [original OFT paper](https://arxiv.org/abs/2306.07280). + +In this guide we provide a Dreambooth fine-tuning script that is available in [PEFT's GitHub repo examples](https://github.com/huggingface/peft/tree/main/examples/boft_dreambooth). This implementation is adapted from [peft's lora_dreambooth](https://github.com/huggingface/peft/tree/main/examples/lora_dreambooth). You can try it out and finetune on your custom images. + +## Set up your environment + +Start by cloning the PEFT repository: + +```bash +git clone --recursive https://github.com/huggingface/peft +``` + +Navigate to the directory containing the training scripts for fine-tuning Dreambooth with BOFT: + +```bash +cd peft/examples/boft_dreambooth +``` + +Set up your environment: install PEFT, and all the required libraries. At the time of writing this guide we recommend installing PEFT from source. The following environment setup should work on A100 and H100: + +```bash +conda create --name peft python=3.10 +conda activate peft +conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia +conda install xformers -c xformers +pip install -r requirements.txt +pip install git+https://github.com/huggingface/peft +``` + +## Download the data + +[dreambooth](https://github.com/google/dreambooth) dataset should have been automatically cloned in the following structure when running the training script. + +``` +boft_dreambooth +├── data +│ ├── data_dir +│ └── dreambooth +│ └── data +│ ├── backpack +│ └── backpack_dog +│ ... +``` + +You can also put your custom images into `boft_dreambooth/data/dreambooth`. + +## Finetune Dreambooth with BOFT + +```bash +./train_dreambooth.sh +``` + +or using the following script arguments: + +```bash +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" +``` + +Here: + +- `INSTANCE_DIR`: The directory containing the images that you intend to use for training your model. +- `CLASS_DIR`: The directory containing class-specific images. In this example, we use prior preservation to avoid overfitting and language-drift. For prior preservation, you need other images of the same class as part of the training process. However, these images can be generated and the training script will save them to a local path you specify here. +- `OUTPUT_DIR`: The destination folder for storing the trained model's weights. + +To learn more about DreamBooth fine-tuning with prior-preserving loss, check out the [Diffusers documentation](https://huggingface.co/docs/diffusers/training/dreambooth#finetuning-with-priorpreserving-loss). + +Launch the training script with `accelerate` and pass hyperparameters, as well as LoRa-specific arguments to it such as: + +- `use_boft`: Enables BOFT in the training script. +- `boft_block_size`: the BOFT matrix block size across different layers, expressed in `int`. Smaller block size results in sparser update matrices with fewer trainable paramters. **Note**, please choose it to be dividable to most layer `in_features` dimension, e.g., 4, 8, 16. Also, you can only specify either `boft_block_size` or `boft_block_num`, but not both simultaneously, because `boft_block_size` x `boft_block_num` = layer dimension. +- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed in `int`. Fewer blocks result in sparser update matrices with fewer trainable paramters. **Note**, please choose it to be dividable to most layer `in_features` dimension, e.g., 4, 8, 16. Also, you can only specify either `boft_block_size` or `boft_block_num`, but not both simultaneously, because `boft_block_size` x `boft_block_num` = layer dimension. +- `boft_n_butterfly_factor`: the number of butterfly factors. **Note**, for `boft_n_butterfly_factor=1`, BOFT is the same as vanilla OFT, for `boft_n_butterfly_factor=2`, the effective block size of OFT becomes twice as big and the number of blocks become half. +- `bias`: specify if the `bias` paramteres should be traind. Can be `none`, `all` or `boft_only`. +- `boft_dropout`: specify the probability of multiplicative dropout. + +Here's what the full set of script arguments may look like: + +```bash +PEFT_TYPE="boft" +BLOCK_NUM=8 +BLOCK_SIZE=0 +N_BUTTERFLY_FACTOR=1 + +VALIDATION_PROMPT=${PROMPT_LIST[@]} +INSTANCE_PROMPT="a photo of ${UNIQUE_TOKEN} ${CLASS_TOKEN}" +CLASS_PROMPT="a photo of ${CLASS_TOKEN}" + +export MODEL_NAME="stabilityai/stable-diffusion-2-1" +# export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export PROJECT_NAME="dreambooth_${PEFT_TYPE}" +export RUN_NAME="${SELECTED_SUBJECT}_${PEFT_TYPE}_${BLOCK_NUM}${BLOCK_SIZE}${N_BUTTERFLY_FACTOR}" +export INSTANCE_DIR="./data/dreambooth/dataset/${SELECTED_SUBJECT}" +export CLASS_DIR="./data/class_data/${CLASS_TOKEN}" +export OUTPUT_DIR="./data/output/${PEFT_TYPE}" + + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir="$CLASS_DIR" \ + --output_dir=$OUTPUT_DIR \ + --wandb_project_name=$PROJECT_NAME \ + --wandb_run_name=$RUN_NAME \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="$INSTANCE_PROMPT" \ + --validation_prompt="$VALIDATION_PROMPT" \ + --class_prompt="$CLASS_PROMPT" \ + --resolution=512 \ + --train_batch_size=1 \ + --num_dataloader_workers=2 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --use_boft \ + --boft_block_num=$BLOCK_NUM \ + --boft_block_size=$BLOCK_SIZE \ + --boft_n_butterfly_factor=$N_BUTTERFLY_FACTOR \ + --boft_dropout=0.1 \ + --boft_bias="boft_only" \ + --learning_rate=3e-5 \ + --max_train_steps=1010 \ + --checkpointing_steps=200 \ + --validation_steps=200 \ + --enable_xformers_memory_efficient_attention \ + --report_to="wandb" \ +``` + +or use this training script: + +```bash +./train_dreambooth.sh $idx +``` + +with the `$idx` corresponds to different subjects. + +If you are running this script on Windows, you may need to set the `--num_dataloader_workers` to 0. + +## Inference with a single adapter + +To run inference with the fine-tuned model, simply run the jupyter notebook `dreambooth_inference.ipynb` for visualization with `jupyter notebook` under `./examples/boft_dreambooth`. diff --git a/examples/boft_dreambooth/data/dreambooth b/examples/boft_dreambooth/data/dreambooth new file mode 160000 index 0000000000..4f887af797 --- /dev/null +++ b/examples/boft_dreambooth/data/dreambooth @@ -0,0 +1 @@ +Subproject commit 4f887af7970a06fc0cd3adaa1d0b368547d6a1d0 diff --git a/examples/boft_dreambooth/dreambooth_inference.ipynb b/examples/boft_dreambooth/dreambooth_inference.ipynb new file mode 100644 index 0000000000..99f6debc83 --- /dev/null +++ b/examples/boft_dreambooth/dreambooth_inference.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "acab479f", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import torch\n", + "from accelerate.logging import get_logger\n", + "from diffusers import StableDiffusionPipeline\n", + "from diffusers.utils import check_min_version\n", + "\n", + "from peft import PeftModel\n", + "\n", + "# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\n", + "check_min_version(\"0.10.0.dev0\")\n", + "\n", + "logger = get_logger(__name__)\n", + "\n", + "MODEL_NAME = \"stabilityai/stable-diffusion-2-1\"\n", + "# MODEL_NAME=\"runwayml/stable-diffusion-v1-5\"\n", + "\n", + "PEFT_TYPE=\"boft\"\n", + "BLOCK_NUM=8\n", + "BLOCK_SIZE=0\n", + "N_BUTTERFLY_FACTOR=1\n", + "SELECTED_SUBJECT=\"backpack\"\n", + "EPOCH_IDX = 200\n", + "\n", + "PROJECT_NAME=f\"dreambooth_{PEFT_TYPE}\"\n", + "RUN_NAME=f\"{SELECTED_SUBJECT}_{PEFT_TYPE}_{BLOCK_NUM}{BLOCK_SIZE}{N_BUTTERFLY_FACTOR}\"\n", + "OUTPUT_DIR=f\"./data/output/{PEFT_TYPE}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06cfd506", + "metadata": {}, + "outputs": [], + "source": [ + "def get_boft_sd_pipeline(\n", + " ckpt_dir, base_model_name_or_path=None, epoch=int, dtype=torch.float32, device=\"cuda\", adapter_name=\"default\"\n", + "):\n", + "\n", + " if base_model_name_or_path is None:\n", + " raise ValueError(\"Please specify the base model name or path\")\n", + "\n", + " pipe = StableDiffusionPipeline.from_pretrained(\n", + " base_model_name_or_path, torch_dtype=dtype, requires_safety_checker=False\n", + " ).to(device)\n", + " \n", + " load_adapter(pipe, ckpt_dir, epoch, adapter_name)\n", + "\n", + " if dtype in (torch.float16, torch.bfloat16):\n", + " pipe.unet.half()\n", + " pipe.text_encoder.half()\n", + "\n", + " pipe.to(device)\n", + " return pipe\n", + "\n", + "\n", + "def load_adapter(pipe, ckpt_dir, epoch, adapter_name=\"default\"):\n", + " \n", + " unet_sub_dir = os.path.join(ckpt_dir, f\"unet/{epoch}\", adapter_name)\n", + " text_encoder_sub_dir = os.path.join(ckpt_dir, f\"text_encoder/{epoch}\", adapter_name)\n", + " \n", + " if isinstance(pipe.unet, PeftModel):\n", + " pipe.unet.load_adapter(unet_sub_dir, adapter_name=adapter_name)\n", + " else:\n", + " pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)\n", + " \n", + " if os.path.exists(text_encoder_sub_dir):\n", + " if isinstance(pipe.text_encoder, PeftModel):\n", + " pipe.text_encoder.load_adapter(text_encoder_sub_dir, adapter_name=adapter_name)\n", + " else:\n", + " pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)\n", + " \n", + "\n", + "def set_adapter(pipe, adapter_name):\n", + " pipe.unet.set_adapter(adapter_name)\n", + " if isinstance(pipe.text_encoder, PeftModel):\n", + " pipe.text_encoder.set_adapter(adapter_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98a0d8ac", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"a photo of sks backpack on a wooden floor\"\n", + "negative_prompt = \"low quality, blurry, unfinished\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4e888d2", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "pipe = get_boft_sd_pipeline(OUTPUT_DIR, MODEL_NAME, EPOCH_IDX, adapter_name=RUN_NAME)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1c1a1c0", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "image = pipe(prompt, num_inference_steps=50, guidance_scale=7, negative_prompt=negative_prompt).images[0]\n", + "image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a1aafdf-8cf7-4e47-9471-26478034245e", + "metadata": {}, + "outputs": [], + "source": [ + "# load and reset another adapter\n", + "# WARNING: requires training DreamBooth with `boft_bias=None`\n", + "\n", + "SELECTED_SUBJECT=\"dog\"\n", + "EPOCH_IDX = 200\n", + "RUN_NAME=f\"{SELECTED_SUBJECT}_{PEFT_TYPE}_{BLOCK_NUM}{BLOCK_SIZE}{N_BUTTERFLY_FACTOR}\"\n", + "\n", + "load_adapter(pipe, OUTPUT_DIR, epoch=EPOCH_IDX, adapter_name=RUN_NAME)\n", + "set_adapter(pipe, adapter_name=RUN_NAME)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7091ad0-2005-4528-afc1-4f9d70a9a535", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "prompt = \"a photo of sks dog running on the beach\"\n", + "negative_prompt = \"low quality, blurry, unfinished\"\n", + "image = pipe(prompt, num_inference_steps=50, guidance_scale=7, negative_prompt=negative_prompt).images[0]\n", + "image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f534eca2-94a4-432b-b092-7149ac44b12f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:peft] *", + "language": "python", + "name": "conda-env-peft-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/boft_dreambooth/requirements.txt b/examples/boft_dreambooth/requirements.txt new file mode 100644 index 0000000000..3cc487940f --- /dev/null +++ b/examples/boft_dreambooth/requirements.txt @@ -0,0 +1,13 @@ +transformers==4.36.2 +accelerate==0.25.0 +evaluate +tqdm +datasets==2.16.1 +diffusers==0.17.1 +Pillow +huggingface_hub +safetensors +nb_conda_kernels +ipykernel +ipywidgets +wandb==0.16.1 \ No newline at end of file diff --git a/examples/boft_dreambooth/train_dreambooth.py b/examples/boft_dreambooth/train_dreambooth.py new file mode 100644 index 0000000000..182dd6a1f3 --- /dev/null +++ b/examples/boft_dreambooth/train_dreambooth.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +import hashlib +import itertools +import logging +import math +import os +from contextlib import nullcontext +from pathlib import Path + +import datasets +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import Repository +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from utils.args_loader import ( + get_full_repo_name, + import_model_class_from_model_name_or_path, + parse_args, +) +from utils.dataset import DreamBoothDataset, PromptDataset, collate_fn +from utils.tracemalloc import TorchTracemalloc, b2mb + +from peft import BOFTConfig, get_peft_model + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.16.0.dev0") + +logger = get_logger(__name__) + +UNET_TARGET_MODULES = ["to_q", "to_v", "to_k", "query", "value", "key", "to_out.0", "add_k_proj", "add_v_proj"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] + + +def save_adaptor(accelerator, step, unet, text_encoder, args): + unwarpped_unet = accelerator.unwrap_model(unet) + unwarpped_unet.save_pretrained( + os.path.join(args.output_dir, f"unet/{step}"), state_dict=accelerator.get_state_dict(unet) + ) + if args.train_text_encoder: + unwarpped_text_encoder = accelerator.unwrap_model(text_encoder) + unwarpped_text_encoder.save_pretrained( + os.path.join(args.output_dir, f"text_encoder/{step}"), + state_dict=accelerator.get_state_dict(text_encoder), + ) + + +def main(args): + validation_prompts = list(filter(None, args.validation_prompt[0].split("."))) + + logging_dir = Path(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=accelerator_project_config, + ) + if args.report_to == "wandb": + import wandb + + wandb_init = { + "wandb": { + "name": args.wandb_run_name, + "mode": "online", + } + } + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + global_seed = hash(args.wandb_run_name) % (2**32) + set_seed(global_seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) # noqa: F841 + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + if args.use_boft: + config = BOFTConfig( + boft_block_size=args.boft_block_size, + boft_block_num=args.boft_block_num, + boft_n_butterfly_factor=args.boft_n_butterfly_factor, + target_modules=UNET_TARGET_MODULES, + boft_dropout=args.boft_dropout, + bias=args.boft_bias, + ) + unet = get_peft_model(unet, config, adapter_name=args.wandb_run_name) + unet.print_trainable_parameters() + + vae.requires_grad_(False) + unet.train() + + if args.train_text_encoder and args.use_boft: + config = BOFTConfig( + boft_block_size=args.boft_block_size, + boft_block_num=args.boft_block_num, + boft_n_butterfly_factor=args.boft_n_butterfly_factor, + target_modules=TEXT_ENCODER_TARGET_MODULES, + boft_dropout=args.boft_dropout, + bias=args.boft_bias, + ) + text_encoder = get_peft_model(text_encoder, config, adapter_name=args.wandb_run_name) + text_encoder.print_trainable_parameters() + text_encoder.train() + else: + text_encoder.requires_grad_(False) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + # below fails when using boft so commenting it out + if args.train_text_encoder and not args.use_boft: + text_encoder.gradient_checkpointing_enable() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = [param for param in unet.parameters() if param.requires_grad] + + if args.train_text_encoder: + params_to_optimize += [param for param in text_encoder.parameters() if param.requires_grad] + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Download the official dreambooth dataset from the official repository: https://github.com/google/dreambooth.git + data_path = os.path.join(os.getcwd(), "data", "dreambooth") + if not os.path.exists(data_path): + os.makedirs(os.path.join(os.getcwd(), "data"), exist_ok=True) + os.system(f"git clone https://github.com/google/dreambooth.git '{data_path}'") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.num_dataloader_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers(args.wandb_project_name, config=vars(args), init_kwargs=wandb_init) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + if args.train_text_encoder: + text_encoder.train() + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + + with TorchTracemalloc() if not args.no_tracemalloc else nullcontext() as tracemalloc: + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + global_step += 1 + + if global_step % args.checkpointing_steps == 0 and global_step != 0: + if accelerator.is_main_process: + save_adaptor(accelerator, global_step, unet, text_encoder, args) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if ( + args.validation_prompt is not None + and (step + num_update_steps_per_epoch * epoch) % args.validation_steps == 0 + and global_step > 10 + ): + unet.eval() + + logger.info( + f"Running validation... \n Generating {len(validation_prompts)} images with prompt:" + f" {validation_prompts[0]}, ......" + ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + safety_checker=None, + revision=args.revision, + ) + # set `keep_fp32_wrapper` to True because we do not want to remove + # mixed precision hooks while we are still training + pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + if args.seed is not None: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None + # images = [] + # for _ in range(args.num_validation_images): + # image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + # images.append(image) + + images = [] + val_img_dir = os.path.join( + args.output_dir, + f"validation/{global_step}", + args.wandb_run_name, + ) + os.makedirs(val_img_dir, exist_ok=True) + + for val_promot in validation_prompts: + image = pipeline(val_promot, num_inference_steps=50, generator=generator).images[0] + image.save(os.path.join(val_img_dir, f"{'_'.join(val_promot.split(' '))}.png"[1:])) + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + import wandb + + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {validation_prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + if global_step >= args.max_train_steps: + break + + # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage + if not args.no_tracemalloc: + accelerator.print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}") + accelerator.print(f"GPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}") + accelerator.print(f"GPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}") + accelerator.print( + f"GPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}" + ) + + accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}") + accelerator.print(f"CPU Memory consumed at the end of the train (end-begin): {tracemalloc.cpu_used}") + accelerator.print(f"CPU Peak Memory consumed during the train (max-begin): {tracemalloc.cpu_peaked}") + accelerator.print( + f"CPU Total Peak Memory consumed during the train (max): {tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)}" + ) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/boft_dreambooth/train_dreambooth.sh b/examples/boft_dreambooth/train_dreambooth.sh new file mode 100755 index 0000000000..f886a4fd1d --- /dev/null +++ b/examples/boft_dreambooth/train_dreambooth.sh @@ -0,0 +1,191 @@ +IDX=$1 +PROMPT_IDX=$((IDX % 25)) +CLASS_IDX=$((IDX % 30)) + +# Define the UNIQUE_TOKEN, CLASS_TOKENs, and SUBJECT_NAMES +UNIQUE_TOKEN="qwe" + +SUBJECT_NAMES=( + "backpack" "backpack_dog" "bear_plushie" "berry_bowl" "can" + "candle" "cat" "cat2" "clock" "colorful_sneaker" + "dog" "dog2" "dog3" "dog5" "dog6" + "dog7" "dog8" "duck_toy" "fancy_boot" "grey_sloth_plushie" + "monster_toy" "pink_sunglasses" "poop_emoji" "rc_car" "red_cartoon" + "robot_toy" "shiny_sneaker" "teapot" "vase" "wolf_plushie" +) + +CLASS_TOKENs=( + "backpack" "backpack" "stuffed animal" "bowl" "can" + "candle" "cat" "cat" "clock" "sneaker" + "dog" "dog" "dog" "dog" "dog" + "dog" "dog" "toy" "boot" "stuffed animal" + "toy" "glasses" "toy" "toy" "cartoon" + "toy" "sneaker" "teapot" "vase" "stuffed animal" +) + +CLASS_TOKEN=${CLASS_TOKENs[$CLASS_IDX]} +SELECTED_SUBJECT=${SUBJECT_NAMES[$CLASS_IDX]} + +if [[ $CLASS_IDX =~ ^(0|1|2|3|4|5|8|9|17|18|19|20|21|22|23|24|25|26|27|28|29)$ ]]; then + PROMPT_LIST=( + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} in the jungle." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} in the snow." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on the beach." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on a cobblestone street." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of pink fabric." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of a wooden floor." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} with a city in the background." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} with a mountain in the background." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} with a blue house in the background." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of a purple rug in a forest." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} with a wheat field in the background." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} with a tree and autumn leaves in the background." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} with the Eiffel Tower in the background." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} floating on top of water." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} floating in an ocean of milk." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of green grass with sunflowers around it." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of a mirror." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of the sidewalk in a crowded street." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of a dirt road." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of a white rug." + "a red ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + "a purple ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + "a shiny ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + "a wet ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + "a cube shaped ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + ) + + prompt_test_list=( + "a ${CLASS_TOKEN} in the jungle" + "a ${CLASS_TOKEN} in the snow" + "a ${CLASS_TOKEN} on the beach" + "a ${CLASS_TOKEN} on a cobblestone street" + "a ${CLASS_TOKEN} on top of pink fabric" + "a ${CLASS_TOKEN} on top of a wooden floor" + "a ${CLASS_TOKEN} with a city in the background" + "a ${CLASS_TOKEN} with a mountain in the background" + "a ${CLASS_TOKEN} with a blue house in the background" + "a ${CLASS_TOKEN} on top of a purple rug in a forest" + "a ${CLASS_TOKEN} with a wheat field in the background" + "a ${CLASS_TOKEN} with a tree and autumn leaves in the background" + "a ${CLASS_TOKEN} with the Eiffel Tower in the background" + "a ${CLASS_TOKEN} floating on top of water" + "a ${CLASS_TOKEN} floating in an ocean of milk" + "a ${CLASS_TOKEN} on top of green grass with sunflowers around it" + "a ${CLASS_TOKEN} on top of a mirror" + "a ${CLASS_TOKEN} on top of the sidewalk in a crowded street" + "a ${CLASS_TOKEN} on top of a dirt road" + "a ${CLASS_TOKEN} on top of a white rug" + "a red ${CLASS_TOKEN}" + "a purple ${CLASS_TOKEN}" + "a shiny ${CLASS_TOKEN}" + "a wet ${CLASS_TOKEN}" + "a cube shaped ${CLASS_TOKEN}" + ) + +else + PROMPT_LIST=( + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} in the jungle." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} in the snow." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on the beach." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on a cobblestone street." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of pink fabric." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of a wooden floor." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} with a city in the background." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} with a mountain in the background." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} with a blue house in the background." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} on top of a purple rug in a forest." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} wearing a red hat." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} wearing a santa hat." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} wearing a rainbow scarf." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} wearing a black top hat and a monocle." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} in a chef outfit." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} in a firefighter outfit." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} in a police outfit." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} wearing pink glasses." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} wearing a yellow shirt." + "a ${UNIQUE_TOKEN} ${CLASS_TOKEN} in a purple wizard outfit." + "a red ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + "a purple ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + "a shiny ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + "a wet ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + "a cube shaped ${UNIQUE_TOKEN} ${CLASS_TOKEN}." + ) + + prompt_test_list=( + "a ${CLASS_TOKEN} in the jungle" + "a ${CLASS_TOKEN} in the snow" + "a ${CLASS_TOKEN} on the beach" + "a ${CLASS_TOKEN} on a cobblestone street" + "a ${CLASS_TOKEN} on top of pink fabric" + "a ${CLASS_TOKEN} on top of a wooden floor" + "a ${CLASS_TOKEN} with a city in the background" + "a ${CLASS_TOKEN} with a mountain in the background" + "a ${CLASS_TOKEN} with a blue house in the background" + "a ${CLASS_TOKEN} on top of a purple rug in a forest" + "a ${CLASS_TOKEN} wearing a red hat" + "a ${CLASS_TOKEN} wearing a santa hat" + "a ${CLASS_TOKEN} wearing a rainbow scarf" + "a ${CLASS_TOKEN} wearing a black top hat and a monocle" + "a ${CLASS_TOKEN} in a chef outfit" + "a ${CLASS_TOKEN} in a firefighter outfit" + "a ${CLASS_TOKEN} in a police outfit" + "a ${CLASS_TOKEN} wearing pink glasses" + "a ${CLASS_TOKEN} wearing a yellow shirt" + "a ${CLASS_TOKEN} in a purple wizard outfit" + "a red ${CLASS_TOKEN}" + "a purple ${CLASS_TOKEN}" + "a shiny ${CLASS_TOKEN}" + "a wet ${CLASS_TOKEN}" + "a cube shaped ${CLASS_TOKEN}" + ) +fi + +VALIDATION_PROMPT=${PROMPT_LIST[@]} +INSTANCE_PROMPT="a photo of ${UNIQUE_TOKEN} ${CLASS_TOKEN}" +CLASS_PROMPT="a photo of ${CLASS_TOKEN}" + +export MODEL_NAME="stabilityai/stable-diffusion-2-1" +# export MODEL_NAME="runwayml/stable-diffusion-v1-5" + +PEFT_TYPE="boft" +BLOCK_NUM=8 +BLOCK_SIZE=0 +N_BUTTERFLY_FACTOR=1 + +export PROJECT_NAME="dreambooth_${PEFT_TYPE}" +export RUN_NAME="${SELECTED_SUBJECT}_${PEFT_TYPE}_${BLOCK_NUM}${BLOCK_SIZE}${N_BUTTERFLY_FACTOR}" +export INSTANCE_DIR="./data/dreambooth/dataset/${SELECTED_SUBJECT}" +export CLASS_DIR="./data/class_data/${CLASS_TOKEN}" +export OUTPUT_DIR="./data/output/${PEFT_TYPE}" + + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir="$CLASS_DIR" \ + --output_dir=$OUTPUT_DIR \ + --wandb_project_name=$PROJECT_NAME \ + --wandb_run_name=$RUN_NAME \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="$INSTANCE_PROMPT" \ + --validation_prompt="$VALIDATION_PROMPT" \ + --class_prompt="$CLASS_PROMPT" \ + --resolution=512 \ + --train_batch_size=1 \ + --num_dataloader_workers=2 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --use_boft \ + --boft_block_num=$BLOCK_NUM \ + --boft_block_size=$BLOCK_SIZE \ + --boft_n_butterfly_factor=$N_BUTTERFLY_FACTOR \ + --boft_dropout=0.1 \ + --boft_bias="boft_only" \ + --learning_rate=3e-5 \ + --max_train_steps=1010 \ + --checkpointing_steps=200 \ + --validation_steps=200 \ + --enable_xformers_memory_efficient_attention \ + --report_to="wandb" \ \ No newline at end of file diff --git a/examples/boft_dreambooth/utils/__init__.py b/examples/boft_dreambooth/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/boft_dreambooth/utils/args_loader.py b/examples/boft_dreambooth/utils/args_loader.py new file mode 100644 index 0000000000..dd946e20f5 --- /dev/null +++ b/examples/boft_dreambooth/utils/args_loader.py @@ -0,0 +1,363 @@ +import argparse +import os +import warnings +from typing import Optional + +from huggingface_hub import HfFolder, whoami +from transformers import PretrainedConfig + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a Dreambooth training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--validation_prompt", + nargs="+", + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=500, + help=( + "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + + # boft args + parser.add_argument("--use_boft", action="store_true", help="Whether to use BOFT for parameter efficient tuning") + parser.add_argument("--boft_block_num", type=int, default=4, help="The number of BOFT blocks") + parser.add_argument("--boft_block_size", type=int, default=0, help="The size of BOFT blocks") + parser.add_argument("--boft_n_butterfly_factor", type=int, default=2, help="The number of butterfly factors") + parser.add_argument("--boft_dropout", type=float, default=0.1, help="BOFT dropout, only used if use_boft is True") + parser.add_argument( + "--boft_bias", + type=str, + default="none", + help="Bias type for BOFT. Can be 'none', 'all' or 'boft_only', only used if use_boft is True", + ) + parser.add_argument( + "--num_dataloader_workers", type=int, default=1, help="Num of workers for the training dataloader." + ) + parser.add_argument( + "--no_tracemalloc", + default=False, + action="store_true", + help="Flag to stop memory allocation tracing during training. This could speed up training on Windows.", + ) + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="wandb", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--wandb_key", + type=str, + default=None, + help=("If report to option is set to wandb, api-key for wandb used for login to wandb "), + ) + parser.add_argument( + "--wandb_project_name", + type=str, + default=None, + help=("If report to option is set to wandb, project name in wandb for log tracking "), + ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help=("If report to option is set to wandb, project name in wandb for log tracking "), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + # if args.dataset_name is None and args.train_data_dir is None: + # raise ValueError("Need either a dataset name or a training folder.") + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args diff --git a/examples/boft_dreambooth/utils/dataset.py b/examples/boft_dreambooth/utils/dataset.py new file mode 100644 index 0000000000..7a968705cf --- /dev/null +++ b/examples/boft_dreambooth/utils/dataset.py @@ -0,0 +1,126 @@ +from pathlib import Path + +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def collate_fn(examples, with_prior_preservation=False): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example diff --git a/examples/boft_dreambooth/utils/tracemalloc.py b/examples/boft_dreambooth/utils/tracemalloc.py new file mode 100644 index 0000000000..ce31dbb208 --- /dev/null +++ b/examples/boft_dreambooth/utils/tracemalloc.py @@ -0,0 +1,58 @@ +import gc +import threading + +import psutil +import torch + + +# Converting Bytes to Megabytes +def b2mb(x): + return int(x / 2**20) + + +# This context manager is used to track the peak memory usage of the process +class TorchTracemalloc: + def __enter__(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.cuda.memory_allocated() + self.process = psutil.Process() + + self.cpu_begin = self.cpu_mem_used() + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + return self + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_peak = -1 + + while True: + self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def __exit__(self, *exc): + self.peak_monitoring = False + + gc.collect() + torch.cuda.empty_cache() + self.end = torch.cuda.memory_allocated() + self.peak = torch.cuda.max_memory_allocated() + self.used = b2mb(self.end - self.begin) + self.peaked = b2mb(self.peak - self.begin) + + self.cpu_end = self.cpu_mem_used() + self.cpu_used = b2mb(self.cpu_end - self.cpu_begin) + self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin) + # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") diff --git a/pyproject.toml b/pyproject.toml index 248dc0211a..c920f4f9d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,4 +41,4 @@ markers = [ "multi_gpu_tests: tests that run on multiple GPUs", "regression: whether to run regression suite test", "bitsandbytes: select bitsandbytes integration tests" -] +] \ No newline at end of file diff --git a/setup.py b/setup.py index 992de0aa1c..16da054a59 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ url="https://github.com/huggingface/peft", package_dir={"": "src"}, packages=find_packages("src"), - package_data={"peft": ["py.typed"]}, + package_data={"peft": ["py.typed", "tuners/boft/fbd/fbd_cuda.cpp", "tuners/boft/fbd/fbd_cuda_kernel.cu"]}, entry_points={}, python_requires=">=3.8.0", install_requires=[ diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 55419e22bc..ded9fddbe4 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -59,6 +59,8 @@ IA3Model, AdaLoraConfig, AdaLoraModel, + BOFTConfig, + BOFTModel, PrefixEncoder, PrefixTuningConfig, PromptEmbedding, diff --git a/src/peft/mapping.py b/src/peft/mapping.py index b62ddf94aa..ad18181416 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -33,6 +33,8 @@ AdaLoraConfig, AdaLoraModel, AdaptionPromptConfig, + BOFTConfig, + BOFTModel, IA3Config, IA3Model, LoHaConfig, @@ -75,6 +77,7 @@ "LOHA": LoHaConfig, "LOKR": LoKrConfig, "ADALORA": AdaLoraConfig, + "BOFT": BOFTConfig, "IA3": IA3Config, "MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig, "OFT": OFTConfig, @@ -86,6 +89,7 @@ "LOHA": LoHaModel, "LOKR": LoKrModel, "ADALORA": AdaLoraModel, + "BOFT": BOFTModel, "IA3": IA3Model, "OFT": OFTModel, "POLY": PolyModel, diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index edf0b92f8e..f38872ee6d 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -41,6 +41,7 @@ from .tuners import ( AdaLoraModel, AdaptionPromptModel, + BOFTModel, IA3Model, LoHaModel, LoKrModel, @@ -80,6 +81,7 @@ PeftType.P_TUNING: PromptEncoder, PeftType.PREFIX_TUNING: PrefixEncoder, PeftType.ADALORA: AdaLoraModel, + PeftType.BOFT: BOFTModel, PeftType.ADAPTION_PROMPT: AdaptionPromptModel, PeftType.IA3: IA3Model, PeftType.OFT: OFTModel, diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index b47baa6681..62e1fb8891 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -23,6 +23,7 @@ from .lokr import LoKrConfig, LoKrModel from .ia3 import IA3Config, IA3Model from .adalora import AdaLoraConfig, AdaLoraModel +from .boft import BOFTConfig, BOFTModel from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType from .prefix_tuning import PrefixEncoder, PrefixTuningConfig from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit diff --git a/src/peft/tuners/boft/__init__.py b/src/peft/tuners/boft/__init__.py new file mode 100644 index 0000000000..5b72b73951 --- /dev/null +++ b/src/peft/tuners/boft/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import BOFTConfig +from .layer import BOFTLayer +from .model import BOFTModel + + +__all__ = ["BOFTConfig", "BOFTLayer", "BOFTModel"] diff --git a/src/peft/tuners/boft/config.py b/src/peft/tuners/boft/config.py new file mode 100644 index 0000000000..ab704b5d95 --- /dev/null +++ b/src/peft/tuners/boft/config.py @@ -0,0 +1,133 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class BOFTConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`BOFTModel`]. + + Args: + boft_block_size (`int`): BOFT block size across different layers. + boft_block_num (`int`): Number of BOFT blocks per injected layer. + boft_n_butterfly_factor (`int`): Number of butterfly factors across different layers. + target_modules (`Union[List[str],str]`): The names of the modules to apply the adapter to. + boft_dropout (`float`): The multiplicative dropout probability for BOFT layers. + fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). + For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set + to `True`. + bias (`str`): Bias type for BOFT. Can be 'none', 'all' or 'boft_only'. If 'all' or 'boft_only', the + corresponding biases will be updated during training. Be aware that this means that, even when disabling + the adapters, the model will not produce the same output as the base model would have without adaptation. + modules_to_save (`List[str]`):List of modules apart from BOFT layers to be set as trainable + and saved in the final checkpoint. + layers_to_transform (`Union[List[int],int]`): + The layer indexes to transform, if this argument is specified, it will apply the BOFT transformations on + the layer indexes that are specified in this list. If a single integer is passed, it will apply the BOFT + transformations on the layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer + pattern is not in the common layers pattern. + """ + + boft_block_size: int = field( + default=4, + metadata={ + "help": "BOFT block size across different layers.", + "note": "You can only specify either boft_block_size or boft_block_num, but not both simultaneously, because boft_block_size x boft_block_num = layer dimension.", + }, + ) + boft_block_num: int = field( + default=0, + metadata={ + "help": "Number of BOFT blocks per injected layer.", + "note": "You can only specify either boft_block_size or boft_block_num, but not both simultaneously, because boft_block_size x boft_block_num = layer dimension.", + }, + ) + boft_n_butterfly_factor: int = field( + default=1, + metadata={ + "help": "Number of butterfly factors.", + "note": ( + "for example, boft_n_butterfly_factor=2, the effective block size of OFT becomes twice as big and the number of blocks become half.", + "note: for boft_n_butterfly_factor=1, BOFT is the same as vanilla OFT.", + ), + }, + ) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with BOFT.", + "example": "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' ", + }, + ) + boft_dropout: float = field(default=0.0, metadata={"help": "BOFT multiplicative dropout"}) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + bias: str = field(default="none", metadata={"help": "Bias type for BOFT. Can be 'none', 'all' or 'boft_only'"}) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from BOFT layers to be set as trainable and saved in the final checkpoint. ", + "note": ( + "For example, in Sequence Classification or Token Classification tasks, ", + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved.", + ), + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the BOFT layers with their default initialization. Don't change ", + "this setting, except if you know exactly what you're doing.", + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.BOFT + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + if self.boft_block_size == 0 and self.boft_block_num == 0: + raise ValueError("You must specify either boft_block_size or boft_block_num.") + if not (self.boft_block_size != 0) ^ (self.boft_block_num != 0): + raise ValueError( + f"You can only specify either boft_block_size ({self.boft_block_size}) or boft_block_num ({self.boft_block_num}), " + "but not both simultaneously, because boft_block_size x boft_block_num != in_features." + ) diff --git a/src/peft/tuners/boft/fbd/__init__.py b/src/peft/tuners/boft/fbd/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/peft/tuners/boft/fbd/fbd_cuda.cpp b/src/peft/tuners/boft/fbd/fbd_cuda.cpp new file mode 100644 index 0000000000..d63111b040 --- /dev/null +++ b/src/peft/tuners/boft/fbd/fbd_cuda.cpp @@ -0,0 +1,28 @@ +#include +#include +#include +#include + +std::vector forward_fast_block_diag_cuda( + at::Tensor input); + +std::vector forward_fast_block_diag( + at::Tensor input + ) { + return forward_fast_block_diag_cuda(input); +} + +std::vector backward_fast_block_diag_cuda( + at::Tensor grad_output, + at::Tensor input); +std::vector backward_fast_block_diag( + at::Tensor grad_output, + at::Tensor input + ) { + return backward_fast_block_diag_cuda(grad_output, input); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward_fast_block_diag, "FAST BLOCK DIAG (CUDA)"); + m.def("backward", &backward_fast_block_diag, "FAST BLOCK DIAG backward (CUDA)"); +} diff --git a/src/peft/tuners/boft/fbd/fbd_cuda_kernel.cu b/src/peft/tuners/boft/fbd/fbd_cuda_kernel.cu new file mode 100644 index 0000000000..7d683371e2 --- /dev/null +++ b/src/peft/tuners/boft/fbd/fbd_cuda_kernel.cu @@ -0,0 +1,109 @@ +// Author: Yao Feng +// Date: 2023/08 +// Description: cuda kernel for fast block diag + +#include + +#include +#include +#include + +namespace{ +template +__global__ void forward_fast_block_diag_cuda_kernel( + const scalar_t* __restrict__ input, //[z, N, b, b] + scalar_t* output, //[z, Nxb, Nxb] + int z, int N, int b + ) { + + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= z*N*b*b) { + return; + } + const int zi = i/(N*b*b); + const int Ni = (i%(N*b*b))/(b*b); + const int x = ((i%(N*b*b))%(b*b))/b; + const int y = ((i%(N*b*b))%(b*b))%b; + + output[zi*N*b*N*b + (Ni*b+x)*N*b + Ni*b + y] = input[zi*N*b*b + Ni*b*b + x*b + y]; + +} + +template +__global__ void backward_fast_block_diag_cuda_kernel( + const scalar_t* __restrict__ grad_output, + scalar_t* grad_input, + int z, int N, int b + ) { + + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= z*N*b*b) { + return; + } + const int zi = i/(N*b*b); + const int Ni = (i%(N*b*b))/(b*b); + const int x = ((i%(N*b*b))%(b*b))/b; + const int y = ((i%(N*b*b))%(b*b))%b; + + grad_input[zi*N*b*b + Ni*b*b + x*b + y] = grad_output[zi*N*b*N*b + (Ni*b+x)*N*b + Ni*b + y]; + +} // namespace +} + +std::vector forward_fast_block_diag_cuda( + at::Tensor input + ){ + const auto z = input.size(0); + const auto N = input.size(1); + const auto b = input.size(2); + + // print(channel_size) + const int threads = 512; + const dim3 blocks_1 ((z*N*b*b - 1) / threads +1); + // initlaize output + auto output = at::zeros({z, N*b, N*b}, input.options()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "forward_fast_block_diag1", ([&] { + forward_fast_block_diag_cuda_kernel<<>>( + input.data(), + output.data(), + z, N, b); + })); + + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + printf("Error in forward_fast_block_diag_cuda_kernel: %s\n", cudaGetErrorString(err)); + + return {output}; +} + +std::vector backward_fast_block_diag_cuda( + at::Tensor grad_output, + at::Tensor input + ){ + + const auto z = input.size(0); + const auto N = input.size(1); + const auto b = input.size(2); + + // print(channel_size) + const int threads = 512; + const dim3 blocks_1 ((z*N*b*b - 1) / threads +1); + + // initialize grad input + auto grad_input = at::zeros_like(input); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_output.type(), "backward_fast_block_diag", ([&] { + backward_fast_block_diag_cuda_kernel<<>>( + grad_output.data(), + grad_input.data(), + z, N, b); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + printf("Error in backward_fast_block_diag_cuda_kernel: %s\n", cudaGetErrorString(err)); + + return {grad_input}; +} diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py new file mode 100644 index 0000000000..7473d32e17 --- /dev/null +++ b/src/peft/tuners/boft/layer.py @@ -0,0 +1,943 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +from __future__ import annotations + +import math +import os +import warnings +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge + + +os.environ["CC"] = "gcc" +os.environ["CXX"] = "gcc" +curr_dir = os.path.dirname(__file__) + +_FBD_CUDA = None + + +def get_fbd_cuda(): + global _FBD_CUDA + + if _FBD_CUDA is not None: + return _FBD_CUDA + + curr_dir = os.path.dirname(__file__) + # need ninja to build the extension + try: + fbd_cuda = load( + name="fbd_cuda", + sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"], + verbose=True, + # build_directory='/tmp/' # for debugging + ) + # extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7 + import fbd_cuda + except Exception as e: + warnings.warn(f"Failed to load the CUDA extension: {e}, check if ninja is available.") + warnings.warn("Setting boft_n_butterfly_factor to 1 to speed up the finetuning process.") + fbd_cuda = None + + _FBD_CUDA = fbd_cuda + return _FBD_CUDA + + +class FastBlockDiag(Function): + """ + Implements a custom autograd Function for a fast block diagonal operation using CUDA. + + This function is optimized for 4D tensors where the last two dimensions are equal, representing block diagonal + matrices for efficient computation on CUDA devices. + """ + + @staticmethod + def forward(ctx, input): + """ + The forward method for FastBlockDiag. + + Computes the block diagonal operation on the input tensor using a CUDA-optimized function. This method assumes + that the input is a 4D tensor where the last two dimensions are equal, which represent the blocks to be + diagonalized. + + Parameters: + ctx: A context object that can be used to stash information for backward computation. + input (Tensor): The input tensor of shape (N, D, H, H), where `N` is the batch size, + `D` represents one additional dimension (In BOFT, the number of BOFT blocks), and `H` is the + size of the square blocks along the last two dimensions (In BOFT, the block size). + + Returns: + Tensor: The resulting tensor after applying the block diagonal operation, + will have the shape (N, DxH, DxH). + """ + output = get_fbd_cuda().forward(input)[0] + ctx.save_for_backward(input) + return output + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + grad_input = get_fbd_cuda().backward(grad_output, input)[0] + return grad_input + + +class MultiplicativeDropoutLayer(nn.Module): + """ + Implements the multiplicative dropout layer for BOFT. + """ + + def __init__(self, p=0.0): + """ + Initializes the multiplicative dropout layer. + + Parameters: + p (float): The probability of dropping out a block. Defaults to 0.0. + """ + super().__init__() + self.p = p + + def forward(self, x): + """ + Applies multiplicative dropout to the input tensor. + + Parameters: + x (Tensor): The input tensor of shape (N, D, H, H), where `N` is the batch size, `D` represents + one additional dimension (In BOFT, the number of BOFT blocks), and `H` is the size of the square + blocks along the last two dimensions (In BOFT, the block size). + """ + if self.training: + # Ensure the last two dimensions are the same + if x.shape[-1] != x.shape[-2]: + raise ValueError("The last two dimensions of input should be the same!") + + N, D, H, _ = x.shape + + # Randomly select one from N + n_random = torch.randint(0, N, (1,)).item() + + # Create a mask with 1s for matrices to be replaced with identity and 0s otherwise + num_to_replace = int(self.p * D) + num_zeros = D - num_to_replace + + # Generate a flat tensor with desired number of 1s and 0s + mask = torch.cat([torch.ones(num_to_replace, device=x.device), torch.zeros(num_zeros, device=x.device)]) + + # Shuffle and reshape the mask + mask = mask[torch.randperm(D)].view(1, D, 1, 1) + + full_mask = torch.zeros(N, D, 1, 1, device=x.device) + full_mask[n_random] = mask + + # Use the mask to combine original matrices and identity matrices + eye_matrix = torch.eye(H, device=x.device).repeat(N, D, 1, 1) + x = (1 - full_mask) * x + full_mask * eye_matrix + return x + + +class BOFTLayer(BaseTunerLayer): + """ + Implements the BOFT layer. + """ + + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names = ("boft_R", "boft_s") + # All names of other parameters that may contain adapter-related parameters + other_param_names = ("boft_block_size", "boft_block_num", "boft_dropout") + + def __init__(self, base_layer: nn.Module, **kwargs) -> None: + """ + Initializes the BOFT layer. + + Note, currently only support linear layer and convolutional layer, with further support for other layers to be + added soon. + + Parameters: + base_layer: the pretrained model layer + """ + self.base_layer = base_layer + self.boft_block_size = {} + self.boft_block_num = {} + self.boft_dropout = nn.ModuleDict({}) + self.boft_R = nn.ParameterDict({}) + self.boft_s = nn.ParameterDict({}) + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + self.kwargs = kwargs + + base_layer = self.get_base_layer() + + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, nn.Conv2d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + else: + raise ValueError(f"Unsupported layer type {type(base_layer)}") + + self.in_features = in_features + self.out_features = out_features + + def set_scale(self, adapter, scale): + if adapter not in self.scaling: + # Ignore the case where the adapter is not in the layer + return + + warnings.warn("Scaling operation for BOFT not supported! Automatically set scale to 1.") + + def scale_layer(self, scale: float) -> None: + if scale == 1: + return + + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + + warnings.warn("Scaling operation for BOFT not supported! Automatically set scale to 1.") + + def unscale_layer(self, scale=None) -> None: + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + + warnings.warn("Unscaling operation for BOFT not supported! Keeping scale to 1.") + + def update_layer( + self, adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + ): + """ + Update the linear layer with trainable BOFT weights. Override for other layer types. + """ + # to be consistent with the paper notation + boft_n_butterfly_factor = boft_n_butterfly_factor - 1 + if boft_n_butterfly_factor < 0: + raise ValueError( + f"You can only specify boft_n_butterfly_factor {boft_n_butterfly_factor+1} to be a positive integer number." + ) + + # Initialize the MultiplicativeDropoutLayer for boft_dropout > 0.0. + if boft_dropout > 0.0: + boft_dropout_layer = MultiplicativeDropoutLayer(p=boft_dropout) + else: + boft_dropout_layer = nn.Identity() + self.boft_dropout.update(nn.ModuleDict({adapter_name: boft_dropout_layer})) + + if boft_block_size == 0 and boft_block_num != 0: + if self.in_features % boft_block_num != 0: + raise ValueError( + f"in_features ({self.in_features}) must be divisible by boft_block_num ({boft_block_num})!" + ) + + if boft_n_butterfly_factor != 0: + if boft_n_butterfly_factor > int(math.log2(boft_block_num)): + raise ValueError( + f"Invalid combination of boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_num ({boft_block_num})!" + ) + if boft_block_num % (2**boft_n_butterfly_factor) != 0: + raise ValueError( + f"boft_block_num ({boft_block_num}) must be a multiple of 2 raised to the power of boft_n_butterfly_factor ({boft_n_butterfly_factor+1})!" + ) + + boft_block_size = int(self.in_features // boft_block_num) + + elif boft_block_size != 0 and boft_block_num == 0: + if self.in_features % boft_block_size != 0: + raise ValueError( + f"in_features ({self.in_features}) must be divisible by boft_block_size ({boft_block_size})!" + ) + + if boft_n_butterfly_factor != 0: + if self.in_features < (boft_block_size * (2**boft_n_butterfly_factor)): + raise ValueError( + f"Invalid combination of in_features ({self.in_features}), boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_size ({boft_block_size})!" + ) + if self.in_features % (boft_block_size * (2**boft_n_butterfly_factor)) != 0: + raise ValueError( + f"Invalid combination of in_features ({self.in_features}), boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_size ({boft_block_size})!" + ) + + boft_block_num = int(self.in_features // boft_block_size) + + else: + raise ValueError( + f"You can only specify either boft_block_size ({boft_block_size}) or boft_block_num ({boft_block_num}), but not both simultaneously or setting both" + "to be 0, because boft_block_size x boft_block_num != in_features." + ) + + # In OFT you can specify the number of blocks to be 1 + if boft_n_butterfly_factor != 0: + if boft_block_num % 2 != 0: + raise ValueError(f"boft_block_num ({boft_block_num}) must be an even number!") + + if boft_block_size % 2 != 0: + raise ValueError(f"boft_block_size ({boft_block_size}) must be an even number!") + + # If there is no butterfly factor, then permutation matrix P will be an identity matrix. + P = torch.empty((boft_n_butterfly_factor + 1, self.in_features, self.in_features)) + for i in range(boft_n_butterfly_factor + 1): + perm = self.block_butterfly_perm( + self.in_features, int(boft_block_num / (2 ** (i))), int(boft_block_size / 2), boft_n_butterfly_factor + ) + perm_mat = self.perm2mat(perm) + P[i] = perm_mat + + self.register_buffer("boft_P", P) + + self.boft_R[adapter_name] = nn.Parameter( + torch.zeros(boft_n_butterfly_factor + 1, boft_block_num, boft_block_size, boft_block_size) + ) + self.boft_s[adapter_name] = nn.Parameter(torch.ones(int(self.out_features), 1)) + + self.reset_boft_parameters(adapter_name, init_weights) + + weight = getattr(self, "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + + # set the boft block size and number + self.boft_block_size[adapter_name] = boft_block_size + self.boft_block_num[adapter_name] = boft_block_num + + self.set_adapter(self.active_adapters) + + def reset_boft_parameters(self, adapter_name, init_weights): + """ + Reset the BOFT parameters. + """ + if init_weights is False: + nn.init.normal_(self.boft_R[adapter_name], mean=0.0, std=0.1) + nn.init.normal_(self.boft_s[adapter_name], mean=1.0, std=0.1) + return + + if adapter_name in self.boft_R.keys(): + if init_weights is True: + # initialize R to zero + nn.init.zeros_(self.boft_R[adapter_name]) + nn.init.ones_(self.boft_s[adapter_name]) + else: + raise ValueError(f"Unknown initialization {init_weights=}") + + def perm2mat(self, indices): + """ + Convert permutation indices to permutation matrix. + + Args: + indices: A list of indices representing the permutation. + """ + # Number of indices determines the size of the square matrix + n = len(indices) + + # Initialize a matrix of zeros + perm_mat = torch.zeros((n, n)) + + # Set the 1s according to the indices + for i, idx in enumerate(indices): + perm_mat[i, idx] = 1 + + return perm_mat + + def block_butterfly_perm(self, n, b, r=3, n_butterfly_factor=1): + """ + Define the permutation matrix for the block butterfly permutation. + + Args: + n: size of the permutation matrix + b: desired number of blocks after multiplying with the permutation matrix + r: base block size of the block diagonal matrix, e.g. 2x2, 3x3, 5x5 etc. + """ + + if n_butterfly_factor == 0: + return torch.arange(n) + + if b * r * 2 > n: + raise ValueError("Invalid number of blocks!") + + block_size = int(n // b) + indices = torch.arange(n) + + def sort_block(b, r): + step = b / r + initial_order = torch.arange(b) + sorted_order = torch.empty(b, dtype=torch.long) + + evens = torch.arange(0, step, 2) + odds = torch.arange(1, step, 2) + sorted_seq = torch.cat((evens, odds), dim=0) + for i, pos in enumerate(sorted_seq): + sorted_order[int(i * r) : int(i * r + r)] = initial_order[int(pos * r) : int(pos * r + r)] + return sorted_order + + sorted_order = sort_block(block_size, r) + + for i in range(0, n, block_size): + block_end = i + block_size + tmp_indices = indices[i:block_end] + indices[i:block_end] = tmp_indices[sorted_order] + return indices + + def cayley_batch(self, data): + """ + Perform the Cayley parametrization on a batch of skew-symmetric matrices. + + Args: + data: A batch of skew-symmetric matrices of shape (b, r, c). + """ + b, r, c = data.shape + # Ensure the input matrix is skew-symmetric + skew_mat = 0.5 * (data - data.transpose(1, 2)) + id_mat = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) + + # Perform the Cayley parametrization + Q = torch.linalg.solve(id_mat + skew_mat, id_mat - skew_mat, left=False) + + return Q + + +class Linear(nn.Module, BOFTLayer): + """ + BOFT implemented in a dense layer. + """ + + def __init__( + self, + base_layer, + adapter_name: str, + boft_block_size: int = 8, + boft_block_num: int = 0, + boft_n_butterfly_factor: int = 0, + boft_dropout: float = 0.1, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_weights: Union[bool, str] = True, + is_target_conv_1d_layer: bool = False, + **kwargs, + ) -> None: + super().__init__() + BOFTLayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + + # Attempt to load the CUDA extension during model initialization + if not get_fbd_cuda(): + self.fbd_cuda_available = False + # If the CUDA extension is not available, set the butterfly factor to 1 to speed up the finetuning process + boft_n_butterfly_factor = 1 + else: + self.fbd_cuda_available = True + + self.update_layer( + adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + ) + self.is_target_conv_1d_layer = is_target_conv_1d_layer + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.boft_R.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = orig_weight * boft_s + + if not torch.isfinite(orig_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + self.base_layer.weight.data = orig_weight + else: + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + orig_weight = base_layer.weight.data.clone() + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = orig_weight * boft_s + + self.base_layer.weight.data = orig_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.boft_R.keys(): + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + + orig_weight = self.get_base_layer().weight.data.clone() + orig_weight = torch.transpose(orig_weight, 0, 1) + orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight) + orig_weight = torch.transpose(orig_weight, 0, 1) + + self.get_base_layer().weight.data = orig_weight * (1 / boft_s) + + def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + boft_R = self.boft_R[adapter] + boft_s = self.boft_s[adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + if self.fbd_cuda_available: + block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly) + else: + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0) + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, self.boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(self.boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + return butterfly_oft_mat, boft_s + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + boft_rotation = torch.eye(self.in_features, device=x.device) + boft_scale = torch.ones((int(self.out_features), 1), device=x.device) + + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + boft_R = self.boft_R[active_adapter] + boft_s = self.boft_s[active_adapter] + dropout = self.boft_dropout[active_adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + orth_rotate_butterfly = dropout(orth_rotate_butterfly) + if self.fbd_cuda_available: + block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly) + else: + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0) + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, self.boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(self.boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + boft_rotation = butterfly_oft_mat @ boft_rotation + boft_scale = boft_s * boft_scale + + x = x.to(self.get_base_layer().weight.data.dtype) + + orig_weight = self.get_base_layer().weight.data + orig_weight = torch.transpose(orig_weight, 0, 1) + rotated_weight = torch.mm(boft_rotation, orig_weight) + rotated_weight = torch.transpose(rotated_weight, 0, 1) + + scaled_rotated_weight = rotated_weight * boft_scale + + result = F.linear(input=x, weight=scaled_rotated_weight, bias=self.base_layer.bias) + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "boft." + rep + + +class Conv2d(nn.Module, BOFTLayer): + """ + BOFT implemented in a Conv2d layer. + """ + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + boft_block_size: int = 8, + boft_block_num: int = 0, + boft_n_butterfly_factor: int = 0, + boft_dropout: float = 0.1, + init_weights: Union[bool, str] = True, + **kwargs, + ) -> None: + super().__init__() + BOFTLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + + # Attempt to load the CUDA extension during model initialization + if not get_fbd_cuda(): + self.fbd_cuda_available = False + # If the CUDA extension is not available, set the butterfly factor to 1 to speed up the finetuning process + boft_n_butterfly_factor = 1 + else: + self.fbd_cuda_available = True + + self.update_layer( + adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + ) + + def update_layer( + self, adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + ): + """ + Update the conv2d layer with trainable BOFT weights. + """ + # to be consistent with the paper notation + boft_n_butterfly_factor = boft_n_butterfly_factor - 1 + if boft_n_butterfly_factor < 0: + raise ValueError( + f"You can only specify boft_n_butterfly_factor {boft_n_butterfly_factor+1} to be a positive integer number." + ) + + # Initialize the MultiplicativeDropoutLayer for boft_dropout > 0.0. + if boft_dropout > 0.0: + boft_dropout_layer = MultiplicativeDropoutLayer(p=boft_dropout) + else: + boft_dropout_layer = nn.Identity() + self.boft_dropout.update(nn.ModuleDict({adapter_name: boft_dropout_layer})) + + # layer information from the base layer + base_layer = self.get_base_layer() + conv_filter_dim = self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0] + + # Initialize the BOFT parameters. + if not (boft_block_size != 0) ^ (boft_block_num != 0): + raise ValueError( + f"You can only specify either boft_block_size ({boft_block_size}) or boft_block_num ({boft_block_num}), but not both simultaneously, because boft_block_size x boft_block_num != in_features." + ) + + if boft_block_size == 0 and boft_block_num != 0: + if conv_filter_dim % boft_block_num != 0: + raise ValueError( + f"Convolutional kernel dimension ({conv_filter_dim}) must be divisible by boft_block_num ({boft_block_num})!" + ) + + if boft_n_butterfly_factor != 0: + if boft_n_butterfly_factor > int(math.log2(boft_block_num)): + raise ValueError( + f"Invalid combination of boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_num ({boft_block_num})!" + ) + if boft_block_num % (2**boft_n_butterfly_factor) != 0: + raise ValueError( + f"boft_block_num ({boft_block_num}) must be a multiple of 2 raised to the power of boft_n_butterfly_factor ({boft_n_butterfly_factor+1})!" + ) + + boft_block_size = int(conv_filter_dim // boft_block_num) + + elif boft_block_size != 0 and boft_block_num == 0: + if conv_filter_dim % boft_block_size != 0: + raise ValueError( + f"Convolutional kernel dimension ({conv_filter_dim}) must be divisible by boft_block_size ({boft_block_size})!" + ) + + if boft_n_butterfly_factor != 0: + if conv_filter_dim < (boft_block_size * (2**boft_n_butterfly_factor)): + raise ValueError( + f"Invalid combination of convolutional kernel dimension ({conv_filter_dim}), boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_size ({boft_block_size})!" + ) + if conv_filter_dim % (boft_block_size * (2**boft_n_butterfly_factor)) != 0: + raise ValueError( + f"Invalid combination of convolutional kernel dimension ({conv_filter_dim}), boft_n_butterfly_factor ({boft_n_butterfly_factor+1}) and boft_block_size ({boft_block_size})!" + ) + + boft_block_num = int(conv_filter_dim // boft_block_size) + + else: + raise ValueError("Unknown error!") + + # In OFT you can specify the number of blocks to be 1 + if boft_n_butterfly_factor != 0: + if boft_block_num % 2 != 0: + raise ValueError(f"boft_block_num ({boft_block_num}) must be an even number!") + + if boft_block_size % 2 != 0: + raise ValueError(f"boft_block_size ({boft_block_size}) must be an even number!") + + # If there is no butterfly factor, then permutation matrix P will be an identity matrix. + P = torch.empty((boft_n_butterfly_factor + 1, conv_filter_dim, conv_filter_dim)) + for i in range(boft_n_butterfly_factor + 1): + perm = self.block_butterfly_perm( + conv_filter_dim, int(boft_block_num / (2 ** (i))), int(boft_block_size / 2), boft_n_butterfly_factor + ) + perm_mat = self.perm2mat(perm) + P[i] = perm_mat + + self.register_buffer("boft_P", P) + + self.boft_R[adapter_name] = nn.Parameter( + torch.zeros(boft_n_butterfly_factor + 1, boft_block_num, boft_block_size, boft_block_size) + ) + self.boft_s[adapter_name] = nn.Parameter(torch.ones(1, int(self.out_features))) + + self.reset_boft_parameters(adapter_name, init_weights) + + weight = getattr(self, "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + # set the boft block size and number + self.boft_block_size[adapter_name] = boft_block_size + self.boft_block_num[adapter_name] = boft_block_num + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.boft_R.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + + orig_weight = orig_weight.view( + self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], self.out_features + ) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = orig_weight * boft_s + orig_weight = orig_weight.view( + self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0] + ) + + self.base_layer.weight.data = orig_weight + else: + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + + orig_weight = base_layer.weight.data.clone() + orig_weight = orig_weight.view( + self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], self.out_features + ) + orig_weight = torch.mm(butterfly_oft_mat, orig_weight) + orig_weight = orig_weight * boft_s + orig_weight = orig_weight.view( + self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0] + ) + + self.base_layer.weight.data = orig_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.boft_R.keys(): + butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter) + + orig_weight = self.get_base_layer().weight.data.clone() + orig_weight = orig_weight.view( + self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0], + self.out_features, + ) + orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight) + orig_weight = orig_weight * (1 / boft_s) + orig_weight = orig_weight.view( + self.out_features, + self.in_features, + self.get_base_layer().kernel_size[0], + self.get_base_layer().kernel_size[0], + ) + + self.get_base_layer().weight.data = orig_weight + + def get_delta_weight(self, adapter) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + + boft_R = self.boft_R[adapter] + boft_s = self.boft_s[adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + if self.fbd_cuda_available: + block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly) + else: + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0) + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, self.boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(self.boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + return butterfly_oft_mat, boft_s + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + boft_rotation = torch.eye( + self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], device=x.device + ) + boft_scale = torch.ones((1, int(self.out_features)), device=x.device) + + for active_adapter in self.active_adapters: + if active_adapter not in self.boft_R.keys(): + continue + boft_R = self.boft_R[active_adapter] + boft_s = self.boft_s[active_adapter] + dropout = self.boft_dropout[active_adapter] + + N, D, H, _ = boft_R.shape + boft_R = boft_R.view(N * D, H, H) + orth_rotate_butterfly = self.cayley_batch(boft_R) + orth_rotate_butterfly = orth_rotate_butterfly.view(N, D, H, H) + orth_rotate_butterfly = dropout(orth_rotate_butterfly) + if self.fbd_cuda_available: + block_diagonal_butterfly = FastBlockDiag.apply(orth_rotate_butterfly) + else: + orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0) + block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly)) + block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0) + + butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, self.boft_P.permute(0, 2, 1)) + butterfly_oft_mat_batch = torch.bmm(self.boft_P, butterfly_oft_mat_batch) + butterfly_oft_mat = butterfly_oft_mat_batch[0] + + for i in range(1, butterfly_oft_mat_batch.shape[0]): + butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat + + boft_rotation = butterfly_oft_mat @ boft_rotation + boft_scale = boft_s * boft_scale + + x = x.to(self.base_layer.weight.data.dtype) + + orig_weight = self.base_layer.weight.data + orig_weight = orig_weight.view( + self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], + self.out_features, + ) + rotated_weight = torch.mm(boft_rotation, orig_weight) + + scaled_rotated_weight = rotated_weight * boft_scale + + scaled_rotated_weight = scaled_rotated_weight.view( + self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0] + ) + result = F.conv2d( + input=x, + weight=scaled_rotated_weight, + bias=self.base_layer.bias, + padding=self.base_layer.padding[0], + stride=self.base_layer.stride[0], + ) + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "boft." + rep diff --git a/src/peft/tuners/boft/model.py b/src/peft/tuners/boft/model.py new file mode 100644 index 0000000000..5acdfe9b4a --- /dev/null +++ b/src/peft/tuners/boft/model.py @@ -0,0 +1,333 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Parameter-Efficient Orthogonal Finetuning +# via Butterfly Factorization" (https://arxiv.org/abs/2311.06243) in ICLR 2024. + +import warnings +from dataclasses import asdict +from enum import Enum +from typing import List, Optional + +import torch +from torch import nn +from tqdm import tqdm + +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _get_submodules, +) + +from .config import BOFTConfig +from .layer import BOFTLayer, Conv2d, Linear + + +class BOFTModel(BaseTuner): + """ + Creates BOFT and OFT model from a pretrained transformers model. Paper: https://arxiv.org/abs/2311.06243 + https://arxiv.org/abs/2306.07280 + + Args: + model ([`transformers.PreTrainedModel`]): The model to be adapted. + config ([`BOFTConfig`]): The configuration of the BOFT model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The BOFT model. + + Example:: + + >>> import transformers >>> from transformers import AutoModelForSeq2SeqLM, BOFTConfig >>> from peft import + BOFTConfig, get_peft_model + + >>> config = BOFTConfig( ... boft_block_size=8, ... boft_n_butterfly_factor=1, ... target_modules=["query", + "value", "key", "output.dense", "mlp.fc1", "mlp.fc2"], ... boft_dropout=0.1, ... bias="boft_only", ... + modules_to_save=["classifier"], ... ) + + >>> model = transformers.Dinov2ForImageClassification.from_pretrained( ... "facebook/dinov2-large", ... + num_labels=100, ... ) >>> boft_model = get_peft_model(model, config) + + **Attributes**: + - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`BOFTConfig`]): The configuration of the BOFT model. + """ + + prefix: str = "boft_" + + def __init__(self, model, config, adapter_name) -> None: + super().__init__(model, config, adapter_name) + + def _check_new_adapter_config(self, config: BOFTConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check + # does not fully correspond to the error message. + if (len(self.peft_config) > 1) and (config.bias != "none"): + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." + ) + + @staticmethod + def _check_target_module_exists(boft_config, key): + return check_target_module_exists(boft_config, key) + + def _create_and_replace( + self, + boft_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + bias = hasattr(target, "bias") and target.bias is not None + kwargs = { + "boft_block_size": boft_config.boft_block_size, + "boft_block_num": boft_config.boft_block_num, + "boft_n_butterfly_factor": boft_config.boft_n_butterfly_factor, + "boft_dropout": boft_config.boft_dropout, + "fan_in_fan_out": boft_config.fan_in_fan_out, + "init_weights": boft_config.init_weights, + } + kwargs["bias"] = bias + + # If it is not a BOFTLayer, create a new module, else update it with new adapters + if not isinstance(target, BOFTLayer): + new_module = self._create_new_module(boft_config, adapter_name, target, **kwargs) + if adapter_name != self.active_adapter: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + else: + target.update_layer( + adapter_name, + boft_block_size=boft_config.boft_block_size, + boft_block_num=boft_config.boft_block_num, + boft_n_butterfly_factor=boft_config.boft_n_butterfly_factor, + boft_dropout=boft_config.boft_dropout, + init_weights=boft_config.init_weights, + ) + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if self.prefix in name: + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "boft_only": + for name, m in model.named_modules(): + if isinstance(m, BOFTLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(boft_config, adapter_name, target, **kwargs): + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = boft_config.fan_in_fan_out = False + new_module = Linear(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, torch.nn.Conv2d): + new_module = Conv2d(target, adapter_name, **kwargs) + else: + raise ValueError( + f"Target module {target} is not supported. " + "Currently, only `torch.nn.Linear` and `torch.nn.Conv2d` are supported." + ) + + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name): + for module in self.model.modules(): + if isinstance(module, BOFTLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[List[str]] = None, + ): + self._unloading_checks(adapter_names) + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + setattr(parent, target_name, target.modules_to_save[target.active_adapter]) + + return self.model + + def delete_adapter(self, adapter_name: str) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, BOFTLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None + ) -> torch.nn.Module: + r""" + This method merges the BOFT layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> torch.nn.Module: + """ + Gets back the base model by removing all the boft modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index d4a84435dc..a2d276c10b 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -43,6 +43,7 @@ class PeftType(str, enum.Enum): PREFIX_TUNING = "PREFIX_TUNING" LORA = "LORA" ADALORA = "ADALORA" + BOFT = "BOFT" ADAPTION_PROMPT = "ADAPTION_PROMPT" IA3 = "IA3" LOHA = "LOHA" diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 8b7654f3b3..5b1662eb75 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -98,6 +98,23 @@ def get_peft_model_state_dict( config.rank_pattern = rank_pattern to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name) + elif config.peft_type == PeftType.BOFT: + bias = config.bias + if bias == "none": + to_return = {k: state_dict[k] for k in state_dict if "boft_" in k} + elif bias == "all": + to_return = {k: state_dict[k] for k in state_dict if "boft_" in k or "bias" in k} + elif bias == "boft_only": + to_return = {} + for k in state_dict: + if "boft_" in k: + to_return[k] = state_dict[k] + bias_name = k.split("boft_")[0] + "bias" + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] + else: + raise NotImplementedError + elif config.peft_type == PeftType.LOHA: to_return = {k: state_dict[k] for k in state_dict if "hada_" in k} @@ -253,6 +270,7 @@ def set_peft_model_state_dict( PeftType.IA3, PeftType.OFT, PeftType.POLY, + PeftType.BOFT, ): peft_model_state_dict = {} parameter_prefix = { @@ -263,6 +281,7 @@ def set_peft_model_state_dict( PeftType.LOKR: "lokr_", PeftType.OFT: "oft_", PeftType.POLY: "poly_", + PeftType.BOFT: "boft_", }[config.peft_type] for k, v in state_dict.items(): if parameter_prefix in k: diff --git a/tests/test_config.py b/tests/test_config.py index 1cddf97537..cea7e5efe7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -24,6 +24,7 @@ from peft import ( AdaLoraConfig, AdaptionPromptConfig, + BOFTConfig, IA3Config, LoHaConfig, LoraConfig, @@ -53,6 +54,7 @@ PromptTuningConfig, OFTConfig, PolyConfig, + BOFTConfig, ) @@ -203,7 +205,7 @@ def test_prompt_encoder_warning_num_layers(self): expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used." assert str(record.list[0].message) == expected_msg - @parameterized.expand([LoHaConfig, LoraConfig, IA3Config, OFTConfig]) + @parameterized.expand([LoHaConfig, LoraConfig, IA3Config, OFTConfig, BOFTConfig]) def test_save_pretrained_with_target_modules(self, config_class): # See #1041, #1045 config = config_class(target_modules=["a", "list"]) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 79fe2e31d0..dfadeb114f 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -33,6 +33,7 @@ from peft import ( AdaLoraConfig, + BOFTConfig, IA3Config, LoHaConfig, LoKrConfig, @@ -241,6 +242,80 @@ ("Conv2d 3 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True}), ("Conv2d 4 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "block_share": True}), ("Conv2d 5 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True, "block_share": True}), + ######## + # BOFT # + ######## + ("Vanilla MLP 1 BOFT", "MLP", BOFTConfig, {"target_modules": ["lin1"], "boft_block_size": 2}), + ( + "Vanilla MLP 2 BOFT", + "MLP", + BOFTConfig, + {"target_modules": ["lin1"], "modules_to_save": ["lin0"], "boft_block_size": 2}, + ), + ( + "Vanilla MLP 3 BOFT", + "MLP", + BOFTConfig, + { + "target_modules": ["lin1"], + "boft_block_size": 2, + "boft_dropout": 0.1, + }, + ), + ( + "Vanilla MLP 4 BOFT", + "MLP", + BOFTConfig, + {"target_modules": ["lin1"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 1}, + ), + ( + "Vanilla MLP 5 BOFT", + "MLP", + BOFTConfig, + {"target_modules": ["lin1"], "boft_block_size": 0, "boft_block_num": 2, "boft_n_butterfly_factor": 1}, + ), + ( + "Vanilla MLP 6 BOFT", + "MLP", + BOFTConfig, + {"target_modules": ["lin1"], "boft_block_size": 10, "boft_block_num": 0, "boft_n_butterfly_factor": 2}, + ), + ( + "Conv2d 1 BOFT", + "Conv2d", + BOFTConfig, + {"target_modules": ["conv2d"], "boft_block_size": 45, "boft_block_num": 0, "boft_n_butterfly_factor": 1}, + ), + ( + "Conv2d 2 BOFT", + "Conv2d", + BOFTConfig, + {"target_modules": ["conv2d"], "boft_block_size": 0, "boft_block_num": 1, "boft_n_butterfly_factor": 1}, + ), + ( + "MLP2 1 BOFT", + "MLP2", + BOFTConfig, + {"target_modules": ["lin1"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 3}, + ), + ( + "MLP2 2 BOFT", + "MLP2", + BOFTConfig, + {"target_modules": ["lin1"], "boft_block_size": 0, "boft_block_num": 8, "boft_n_butterfly_factor": 3}, + ), + ( + "Conv2d2 1 BOFT", + "Conv2d2", + BOFTConfig, + {"target_modules": ["conv2d"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 2}, + ), + ( + "Conv2d2 1 BOFT", + "Conv2d2", + BOFTConfig, + {"target_modules": ["conv2d"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 3}, + ), ] MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [ @@ -309,6 +384,7 @@ LoHaConfig: "hada_", LoKrConfig: "lokr_", OFTConfig: "oft_", + BOFTConfig: "boft_", } @@ -331,6 +407,25 @@ def forward(self, X): return X +class MLP2(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.lin0 = nn.Linear(10, 32, bias=bias) + self.relu = nn.ReLU() + self.drop = nn.Dropout(0.5) + self.lin1 = nn.Linear(32, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.drop(X) + X = self.lin1(X) + X = self.sm(X) + return X + + class Block(nn.Module): def __init__(self, bias=True): super().__init__() @@ -430,6 +525,29 @@ def forward(self, X): return X +class ModelConv2D2(nn.Module): + def __init__(self): + super().__init__() + self.lin0 = nn.Linear(10, 40) + self.conv2d = nn.Conv2d(8, 32, 3) + self.relu = nn.ReLU() + self.flat = nn.Flatten() + self.lin1 = nn.Linear(32, 2) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = X.reshape(-1, 8, 3, 3) + X = self.conv2d(X) + X = self.relu(X) + X = self.flat(X) + X = self.lin1(X) + X = self.sm(X) + return X + + class MockTransformerWrapper: """Mock class to behave like a transformers model. @@ -454,6 +572,12 @@ def from_pretrained(cls, model_id, torch_dtype=None): if model_id == "Conv2d": return ModelConv2D().to(torch_dtype) + if model_id == "MLP2": + return MLP2().to(torch_dtype) + + if model_id == "Conv2d2": + return ModelConv2D2().to(torch_dtype) + raise ValueError(f"model_id {model_id} not implemented") @@ -769,12 +893,12 @@ def test_disable_adapter_with_bias_warns(self, test_name, model_id, config_cls, # Note: We test only with custom models since they run really fast. There is really no point in testing the same # thing with decoder, encoder_decoder, etc. - if config_cls != LoraConfig: + if config_cls != LoraConfig or config_cls != BOFTConfig: # skip this test for other configs as bias is specific to Lora - self.skipTest("Testing bias warnings only for LoraConfig") + self.skipTest("Testing bias warnings only for LoraConfig or BOFTConfig") - if not issubclass(config_cls, LoraConfig): - self.skipTest("Bias argument is only supported for LoRA models") + if not issubclass(config_cls, (LoraConfig, BOFTConfig)): + self.skipTest("Bias argument is only supported for LoRA or BOFT models") def run_with_disable(config_kwargs, bias): config_kwargs = config_kwargs.copy() @@ -788,12 +912,21 @@ def run_with_disable(config_kwargs, bias): with peft_model.disable_adapter(): pass # there is nothing to be done - # check that bias=all and bias=lora_only give a warning with the correct message - msg_start = "Careful, disabling adapter layers with bias configured to be" - with pytest.warns(UserWarning, match=msg_start): - run_with_disable(config_kwargs, bias="lora_only") - with pytest.warns(UserWarning, match=msg_start): - run_with_disable(config_kwargs, bias="all") + if config_cls == LoraConfig: + # check that bias=all and bias=lora_only give a warning with the correct message + msg_start = "Careful, disabling adapter layers with bias configured to be" + with pytest.warns(UserWarning, match=msg_start): + run_with_disable(config_kwargs, bias="lora_only") + with pytest.warns(UserWarning, match=msg_start): + run_with_disable(config_kwargs, bias="all") + + if config_cls == BOFTConfig: + # check that bias=all and bias=boft_only give a warning with the correct message + msg_start = "Careful, disabling adapter layers with bias configured to be" + with pytest.warns(UserWarning, match=msg_start): + run_with_disable(config_kwargs, bias="boft_only") + with pytest.warns(UserWarning, match=msg_start): + run_with_disable(config_kwargs, bias="all") # For bias=none, there is no warning. Unfortunately, AFAIK unittest has no option to assert that no warning is # given, therefore, we check that the unittest gives us an AssertionError if we check for a warning @@ -1076,6 +1209,7 @@ def test_load_resized_embedding_ignore_mismatched_sizes(self): AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"], init_ia3_weights=False), OFTConfig(target_modules=["lin0"], init_weights=False), + BOFTConfig(target_modules=["lin0"], init_weights=False, boft_block_size=2), ] ) def test_adapter_name_makes_no_difference(self, config0): @@ -2215,6 +2349,91 @@ def test_requires_grad_oft_same_targets(self): "base_model.model.lin0.oft_r.adapter1", ) + def test_requires_grad_boft_different_targets(self): + # test two different OFT adapters that target different modules + config0 = BOFTConfig(target_modules=["lin0"], boft_block_size=2) + peft_model = get_peft_model(MLP2(), config0) + + config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2, inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active pter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin0.boft_R.default", + "base_model.model.lin0.boft_s.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.boft_R.default", + "base_model.model.lin0.boft_s.default", + ) + + # change activate pter to pter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.boft_R.adapter1", + "base_model.model.lin1.boft_s.adapter1", + ) + + # disable all pters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin1.boft_R.adapter1", + "base_model.model.lin1.boft_s.adapter1", + ) + + def test_requires_grad_boft_same_targets(self): + # same as previous test, except that BOFT adapters target the same layer + config0 = BOFTConfig(target_modules=["lin1"], boft_block_size=2) + peft_model = get_peft_model(MLP(), config0) + + config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2, inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin1.boft_R.default", + "base_model.model.lin1.boft_s.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.boft_R.default", + "base_model.model.lin1.boft_s.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.boft_R.adapter1", + "base_model.model.lin1.boft_s.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.boft_R.adapter1", + "base_model.model.lin1.boft_s.adapter1", + ) + class TestMixedAdapterBatches: torch_device = infer_device() diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index cd07c0eab2..642cc99e6a 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -19,7 +19,7 @@ from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer -from peft import AdaLoraConfig, LoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model +from peft import AdaLoraConfig, BOFTConfig, LoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model from .testing_common import PeftCommonTester, PeftTestConfigManager @@ -45,6 +45,18 @@ def skip_adalora_and_gpt2(test_list): return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == AdaLoraConfig))] +def skip_boft_and_gpt2(test_list): + return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == BOFTConfig))] + + +def skip_adalora_or_boft_and_gpt2(test_list): + return [ + test + for test in test_list + if not (("GPT2LMHeadModel" in test[1]) and ((test[2] == AdaLoraConfig) or (test[2] == BOFTConfig))) + ] + + class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester): r""" Test if the PeftModel behaves as expected. This includes: @@ -66,15 +78,15 @@ def prepare_inputs_for_testing(self): return input_dict - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): self._test_model_attr(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs): self._test_adapter_name(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs): self._test_prepare_for_training(model_id, config_cls, config_kwargs) @@ -132,23 +144,23 @@ def test_prompt_tuning_config_invalid_args(self): tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, ) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs): self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) @@ -158,6 +170,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "model_ids": PEFT_DECODER_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, ) @@ -171,8 +184,10 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): "model_ids": PEFT_DECODER_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, + filter_params_func=skip_boft_and_gpt2, ) ) def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs): @@ -184,6 +199,7 @@ def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs "model_ids": PEFT_DECODER_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, ) @@ -203,11 +219,11 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_generate(self, test_name, model_id, config_cls, config_kwargs): self._test_generate(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs): # positional args are supported for PeftModelForCausalLM self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=False) @@ -224,7 +240,7 @@ def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs def test_prefix_tuning_half_prec_conversion(self, test_name, model_id, config_cls, config_kwargs): self._test_prefix_tuning_half_prec_conversion(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs): self._test_training(model_id, config_cls, config_kwargs) @@ -232,11 +248,11 @@ def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs) def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs): self._test_training_layer_indexing(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs): self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs): self._test_inference_safetensors(model_id, config_cls, config_kwargs) @@ -244,15 +260,15 @@ def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwa def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): self._test_peft_model_device_map(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs): self._test_delete_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs): self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) @@ -263,9 +279,10 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "lora_kwargs": {"init_lora_weights": [False]}, "adalora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, - filter_params_func=skip_adalora_and_gpt2, + filter_params_func=skip_adalora_or_boft_and_gpt2, ) ) def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): @@ -276,6 +293,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): { "model_ids": PEFT_DECODER_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, ) @@ -294,8 +312,10 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "adalora_kwargs": {"init_lora_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, + filter_params_func=skip_boft_and_gpt2, ) ) def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs): @@ -311,7 +331,7 @@ def test_generate_adalora_no_dropout(self): } self._test_generate(model_id, AdaLoraConfig, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) def test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs): self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs) diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index 240f123948..1dbb689a7f 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -169,6 +169,7 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "lora_kwargs": {"init_lora_weights": [False]}, "adalora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) @@ -199,6 +200,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "lora_kwargs": {"init_lora_weights": [False]}, "adalora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py index 84c246f58e..a0cf85f2c2 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -44,9 +44,12 @@ def skip_non_prompt_tuning(test_list): def skip_deberta_lora_tests(test_list): r""" - Skip tests that are checkpointing with lora/ia3 tests for Deberta models (couldn't find much info on the error) + Skip tests that are checkpointing with lora/ia3/boft tests for Deberta models (couldn't find much info on the + error) """ - return [test for test in test_list if not (any(k in test[0] for k in ["lora", "ia3"]) and "Deberta" in test[0])] + return [ + test for test in test_list if not (any(k in test[0] for k in ["lora", "ia3", "boft"]) and "Deberta" in test[0]) + ] def skip_deberta_pt_tests(test_list): @@ -107,6 +110,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) @@ -157,6 +161,7 @@ def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_k "lora_kwargs": {"init_lora_weights": [False]}, "adalora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) @@ -169,6 +174,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): { "model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) diff --git a/tests/test_stablediffusion.py b/tests/test_stablediffusion.py index 9f582a1e4f..b8cc2e203a 100644 --- a/tests/test_stablediffusion.py +++ b/tests/test_stablediffusion.py @@ -19,7 +19,7 @@ from diffusers import StableDiffusionPipeline from parameterized import parameterized -from peft import LoHaConfig, LoraConfig, OFTConfig, get_peft_model +from peft import BOFTConfig, LoHaConfig, LoraConfig, OFTConfig, get_peft_model from .testing_common import ClassInstantier, PeftCommonTester from .testing_utils import temp_seed @@ -71,12 +71,27 @@ "module_dropout": 0.0, }, }, + { + "text_encoder": { + "boft_block_num": 1, + "boft_block_size": 0, + "target_modules": ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + "boft_dropout": 0.0, + }, + "unet": { + "boft_block_num": 1, + "boft_block_size": 0, + "target_modules": ["proj_in", "proj_out", "to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"], + "boft_dropout": 0.0, + }, + }, ) CLASSES_MAPPING = { "lora": (LoraConfig, CONFIG_TESTING_KWARGS[0]), "loha": (LoHaConfig, CONFIG_TESTING_KWARGS[1]), "lokr": (LoHaConfig, CONFIG_TESTING_KWARGS[1]), "oft": (OFTConfig, CONFIG_TESTING_KWARGS[2]), + "boft": (BOFTConfig, CONFIG_TESTING_KWARGS[3]), } @@ -129,6 +144,7 @@ def prepare_inputs_for_testing(self): "lora_kwargs": {"init_lora_weights": [False]}, "loha_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, }, ) ) @@ -161,6 +177,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): "lora_kwargs": {"init_lora_weights": [False]}, "loha_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, }, ) ) @@ -221,6 +238,7 @@ def test_add_weighted_adapter_base_unchanged(self, test_name, model_id, config_c "loha_kwargs": {"init_weights": [False]}, "lokr_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, }, ) ) diff --git a/tests/testing_common.py b/tests/testing_common.py index 0538b6cf59..8632603083 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -27,6 +27,7 @@ from peft import ( AdaLoraConfig, + BOFTConfig, IA3Config, LoHaConfig, LoKrConfig, @@ -78,6 +79,10 @@ { "target_modules": None, }, + # BOFT + { + "target_modules": None, + }, ) CLASSES_MAPPING = { @@ -87,6 +92,7 @@ "prompt_encoder": (PromptEncoderConfig, CONFIG_TESTING_KWARGS[3]), "prompt_tuning": (PromptTuningConfig, CONFIG_TESTING_KWARGS[4]), "adalora": (AdaLoraConfig, CONFIG_TESTING_KWARGS[5]), + "boft": (BOFTConfig, CONFIG_TESTING_KWARGS[6]), } @@ -510,6 +516,9 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): if issubclass(config_cls, PromptLearningConfig): return pytest.skip(f"Test not applicable for {config_cls}") + if issubclass(config_cls, BOFTConfig): + return pytest.skip(f"Test not applicable for {config_cls}") + if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig): self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)") @@ -561,7 +570,7 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_merged, logits_merged_from_pretrained, atol=atol, rtol=rtol) def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT] + supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT, PeftType.BOFT] if ("gpt2" in model_id.lower()) and (config_cls == IA3Config): self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)") @@ -626,9 +635,6 @@ def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_merged_adapter_default, logits_adapter_1, atol=1e-3, rtol=1e-3) def _test_merge_layers_is_idempotent(self, model_id, config_cls, config_kwargs): - if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig): - self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)") - model = self.transformers_class.from_pretrained(model_id) config = config_cls( base_model_name_or_path=model_id, @@ -952,7 +958,14 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa loss = output.sum() loss.backward() - parameter_prefix = "ia3" if config_cls == IA3Config else "lora" + # parameter_prefix = "ia3" if config_cls == IA3Config else "lora" + if config_cls == IA3Config: + parameter_prefix = "ia3" + elif config_cls == BOFTConfig: + parameter_prefix = "boft" + else: + parameter_prefix = "lora" + for n, param in model.named_parameters(): if parameter_prefix in n: assert param.grad is not None @@ -1005,7 +1018,7 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar assert param.grad is not None def _test_delete_adapter(self, model_id, config_cls, config_kwargs): - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT] + supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT, PeftType.BOFT] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters config = config_cls( @@ -1043,7 +1056,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs): def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs): # same as test_delete_adapter, but this time an inactive adapter is deleted - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT] + supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT, PeftType.BOFT] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters config = config_cls( @@ -1088,7 +1101,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs): model = get_peft_model(model, config) model = model.to(self.torch_device) - if config.peft_type not in ("LORA", "ADALORA", "IA3"): + if config.peft_type not in ("LORA", "ADALORA", "IA3", "BOFT"): with pytest.raises(AttributeError): model = model.unload() else: @@ -1377,9 +1390,9 @@ def get_output(model): # TODO: add tests to check if disabling adapters works after calling merge_adapter def _test_adding_multiple_adapters_with_bias_raises(self, model_id, config_cls, config_kwargs): - # When trying to add multiple adapters with bias in Lora or AdaLora, an error should be + # When trying to add multiple adapters with bias in Lora, AdaLora or BOFTConfig, an error should be # raised. Also, the peft model should not be left in a half-initialized state. - if not issubclass(config_cls, (LoraConfig, AdaLoraConfig)): + if not issubclass(config_cls, (LoraConfig, AdaLoraConfig, BOFTConfig)): return pytest.skip(f"Test not applicable for {config_cls}") config_kwargs = config_kwargs.copy() @@ -1391,8 +1404,14 @@ def _test_adding_multiple_adapters_with_bias_raises(self, model_id, config_cls, model = self.transformers_class.from_pretrained(model_id) model = get_peft_model(model, config, "adapter0") - with pytest.raises(ValueError): - model.add_adapter("adapter1", replace(config, r=20)) + + if config_cls == LoraConfig or config_cls == AdaLoraConfig: + with pytest.raises(ValueError): + model.add_adapter("adapter1", replace(config, r=20)) + + if config_cls == BOFTConfig: + with pytest.raises(ValueError): + model.add_adapter("adapter1", replace(config, boft_block_num=1, boft_block_size=0)) # (superficial) test that the model is not left in a half-initialized state when adding an adapter fails assert "adapter1" not in model.peft_config