diff --git a/docs/source/conceptual_guides/lora.md b/docs/source/conceptual_guides/lora.md
index 67e16edda0..d06b127f74 100644
--- a/docs/source/conceptual_guides/lora.md
+++ b/docs/source/conceptual_guides/lora.md
@@ -93,3 +93,22 @@ For an example of LoRA method application to various downstream tasks, please re
 
 While the original paper focuses on language models, the technique can be applied to any dense layers in deep learning models. As such, you can leverage this technique with diffusion models. See [Dreambooth fine-tuning with LoRA](../task_guides/task_guides/dreambooth_lora) task guide for an example.
+
+## Initialization options
+
+The initialization of the LoRA weights is controlled by the `init_lora_weights` parameter of `LoraConfig`. By default, PEFT initializes the LoRA weights the same way as the [reference implementation](https://github.com/microsoft/LoRA), i.e. using Kaiming-uniform for weight A and zeros for weight B, resulting in an identity transform: before any training, the adapter leaves the output of the base model unchanged.
+
+It is also possible to pass `init_lora_weights="gaussian"`. As the name suggests, this initializes weight A from a Gaussian distribution (weight B is still zeros). This corresponds to the way that [diffusers](https://huggingface.co/docs/diffusers/index) initializes LoRA weights.
+
+When quantizing the base model, e.g. for QLoRA training, consider using the [LoftQ initialization](https://arxiv.org/abs/2310.08659), which has been shown to improve performance when training quantized models. The idea is to initialize the LoRA weights such that the quantization error is minimized. To use this option, *do not* quantize the base model. Instead, proceed as follows:
+
+```python
+from peft import LoftQConfig, LoraConfig, get_peft_model
+from transformers import AutoModelForCausalLM
+
+base_model = AutoModelForCausalLM.from_pretrained(...)  # don't quantize here
+loftq_config = LoftQConfig(loftq_bits=4, ...)  # set 4-bit quantization
+lora_config = LoraConfig(..., init_lora_weights="loftq", loftq_config=loftq_config)
+peft_model = get_peft_model(base_model, lora_config)
+```
+
+Finally, there is also an option to set `init_lora_weights=False`. When choosing this option, the LoRA weights are initialized such that they do *not* result in an identity transform. This is useful for debugging and testing purposes and should not be used otherwise.
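+
+As a quick sanity check of the identity-transform behaviour described above, the sketch below compares the base model's output before and after attaching a freshly initialized adapter. The checkpoint name and `target_modules` are placeholder assumptions for illustration only; substitute your own model and modules.
+
+```python
+import torch
+from peft import LoraConfig, get_peft_model
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "facebook/opt-125m"  # placeholder checkpoint, any causal LM works
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+base_model = AutoModelForCausalLM.from_pretrained(model_id)
+
+inputs = tokenizer("LoRA initialization check", return_tensors="pt")
+with torch.no_grad():
+    base_logits = base_model(**inputs).logits
+
+# default init: weight A is Kaiming-uniform, weight B is zeros;
+# pass init_lora_weights="gaussian" instead for the diffusers-style init
+lora_config = LoraConfig(r=8, target_modules=["q_proj", "v_proj"])
+peft_model = get_peft_model(base_model, lora_config)
+
+with torch.no_grad():
+    peft_logits = peft_model(**inputs).logits
+
+# before training, the adapter contributes nothing, so the outputs match
+assert torch.allclose(base_logits, peft_logits)
+```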
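+
+Conversely, the `init_lora_weights=False` option mentioned above can be exercised in a test to confirm that a randomly initialized adapter immediately changes the model's output. This is again only a sketch, with a placeholder checkpoint and target modules:
+
+```python
+from peft import LoraConfig, get_peft_model
+from transformers import AutoModelForCausalLM
+
+# with init_lora_weights=False, weight B is also randomly initialized
+# (instead of zeros), so the untrained adapter is *not* an identity transform
+lora_config = LoraConfig(
+    r=8,
+    target_modules=["q_proj", "v_proj"],  # placeholder target modules
+    init_lora_weights=False,
+)
+base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder checkpoint
+peft_model = get_peft_model(base_model, lora_config)
+```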