diff --git a/examples/safelora/README.md b/examples/safelora/README.md
new file mode 100644
index 0000000000..57e5d1cc26
--- /dev/null
+++ b/examples/safelora/README.md
@@ -0,0 +1,46 @@
+# Safe LoRA
+
+The official implementation of [Safe LoRA: The Silver Lining of Reducing Safety Risks when Fine-tuning Large Language Models](https://arxiv.org/abs/2405.16833).
+
+## Quick Start
+
+### Get Weights with SafeLoRA
+First, import `SafeLoraConfig` and `apply_safelora`.
+Then fill in the paths of the base, aligned, and PEFT models according to your needs. `select_layers_type` supports two modes: with `threshold`, the number of projected layers is determined by the cosine-similarity threshold you set; with `number`, you specify the number of projected layers directly. `save_weights=True` saves the SafeLoRA weights to the original PEFT model path, replacing the previous adapter weights.
+
+```python
+from peft.utils.safelora import SafeLoraConfig, apply_safelora
+
+peft_path = "../finetuneLLM/finetuned_models/samsumBad-7b-fp16-peft-seed-42"
+config = SafeLoraConfig(
+    base_model_path="meta-llama/Llama-2-7b-hf",
+    aligned_model_path="TheBloke/Llama-2-7B-Chat-fp16",
+    peft_model_path=peft_path,
+    device="cuda",
+    select_layers_type="threshold",
+    save_weights=True,
+)
+final_lora_weight = apply_safelora(config)
+```
+
+### Save SafeLoRA's Weights
+If you set `save_weights=False` but still want to save the weights, you can use the following code.
+
+```python
+import os
+
+from safetensors.torch import save_file
+
+path = ...  # your PEFT model path
+save_file(final_lora_weight, os.path.join(path, "adapter_model.safetensors"))
+```
+
+### Use SafeLoRA Model
+Next, load the base model of the PEFT model together with the PEFT model itself to obtain a model that retains both downstream task utility and alignment.
+
+```python
+from transformers import AutoModelForCausalLM
+from peft import PeftModel
+
+model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-Chat-fp16")
+model = PeftModel.from_pretrained(model, peft_path)
+```
diff --git a/examples/safelora/safelora_inference.py b/examples/safelora/safelora_inference.py
new file mode 100644
index 0000000000..2b8f24d9bb
--- /dev/null
+++ b/examples/safelora/safelora_inference.py
@@ -0,0 +1,25 @@
+import os
+
+from safetensors.torch import save_file
+from transformers import AutoModelForCausalLM
+
+from peft import PeftModel
+from peft.utils.safelora import SafeLoraConfig, apply_safelora
+
+
+peft_path = "../finetuneLLM/finetuned_models/samsumBad-7b-fp16-peft-seed-42"
+config = SafeLoraConfig(
+    base_model_path="meta-llama/Llama-2-7b-hf",
+    aligned_model_path="TheBloke/Llama-2-7B-Chat-fp16",
+    peft_model_path=peft_path,
+    device="cuda",
+    select_layers_type="threshold",
+    save_weights=True,
+)
+
+final_lora_weight = apply_safelora(config)
+
+save_file(final_lora_weight, os.path.join(peft_path, "adapter_model.safetensors"))
+
+model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-Chat-fp16")
+model = PeftModel.from_pretrained(model, peft_path)
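The example script above selects projection layers by cosine-similarity threshold. If you would rather fix the exact number of projected layers, `select_layers_type="number"` together with `num_proj_layers` does that; below is a minimal sketch reusing the placeholder paths from the script (the value `7` is only illustrative):

```python
from peft.utils.safelora import SafeLoraConfig, apply_safelora

# Select exactly `num_proj_layers` layers instead of thresholding on cosine similarity.
config = SafeLoraConfig(
    base_model_path="meta-llama/Llama-2-7b-hf",
    aligned_model_path="TheBloke/Llama-2-7B-Chat-fp16",
    peft_model_path="../finetuneLLM/finetuned_models/samsumBad-7b-fp16-peft-seed-42",
    device="cuda",
    select_layers_type="number",  # pin the layer count directly
    num_proj_layers=7,            # illustrative: project the 7 least-aligned layers
    save_weights=False,           # keep the original adapter file untouched
)
final_lora_weight = apply_safelora(config)
```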
diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py
index fda66949f4..b052e37c60 100644
--- a/src/peft/utils/loftq_utils.py
+++ b/src/peft/utils/loftq_utils.py
@@ -23,7 +23,7 @@
 import torch
 from huggingface_hub import snapshot_download
-from huggingface_hub.errors import HFValidationError, LocalEntryNotFoundError
+from huggingface_hub.errors import HFValidationError
 from safetensors import SafetensorError, safe_open
 from transformers.utils import cached_file
 from transformers.utils.hub import get_checkpoint_shard_files
@@ -267,26 +267,31 @@ class _SafetensorLoader:
     """

-    def __init__(self, peft_model, model_path):
+    def __init__(self, peft_model_or_model_id, model_path=None, local_files_only=True):
+        if isinstance(peft_model_or_model_id, str):
+            name_or_path = peft_model_or_model_id
+            base_model_prefix = None
+        else:
+            name_or_path = peft_model_or_model_id.base_model.config._name_or_path
+            base_model_prefix = getattr(peft_model_or_model_id.get_base_model(), "base_model_prefix", None)
         if model_path is None:
-            try:
-                model_path = snapshot_download(peft_model.base_model.config._name_or_path, local_files_only=True)
-            except (AttributeError, HFValidationError) as exc:
-                raise ValueError(
-                    "The provided model does not appear to be a transformers model or is a local model. In this case, "
-                    "you must pass the model_path argument that points to the safetensors file."
-                ) from exc
-            except LocalEntryNotFoundError as exc:
-                raise ValueError(
-                    "The model.safetensors file must be present on disk, but it could not be found."
-                ) from exc
+            if os.path.exists(name_or_path):
+                model_path = name_or_path
+            else:
+                try:
+                    model_path = snapshot_download(name_or_path, local_files_only=local_files_only)
+                except (AttributeError, HFValidationError) as exc:
+                    raise ValueError(
+                        "The provided model does not appear to be a transformers model or is a local model. In this case, "
+                        "you must pass the model_path argument that points to the safetensors file."
+                    ) from exc

         suffix = "model.safetensors"
         if not model_path.endswith(suffix):
             model_path = os.path.join(model_path, suffix)

         self.model_path = model_path
-        self.base_model_prefix = getattr(peft_model.get_base_model(), "base_model_prefix", None)
+        self.base_model_prefix = base_model_prefix
         self.prefix = "base_model.model."
         self.is_sharded = False
         self.weight_map = None
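With this change, `_SafetensorLoader` (a private helper, so subject to change) can be constructed from a model id or local path as well as from a PEFT model. A rough sketch of both call styles; the Llama-style weight name is an assumption, not something this patch guarantees:

```python
from peft.utils.loftq_utils import _SafetensorLoader

# New in this patch: build the loader directly from a model id or local path.
sl = _SafetensorLoader("meta-llama/Llama-2-7b-hf", local_files_only=False)
weight = sl.get_tensor("model.layers.0.self_attn.q_proj.weight")  # assumed Llama-style name
print(weight.shape)

# The pre-existing call style with a PeftModel instance still works:
# sl = _SafetensorLoader(peft_model)
```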
diff --git a/src/peft/utils/safelora.py b/src/peft/utils/safelora.py
new file mode 100644
index 0000000000..780ba2fd41
--- /dev/null
+++ b/src/peft/utils/safelora.py
@@ -0,0 +1,227 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Reference paper: https://arxiv.org/abs/2405.16833
+
+import copy
+import os
+from dataclasses import dataclass, field
+from typing import Literal
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+from peft import PeftConfig
+
+from .loftq_utils import _SafetensorLoader as SafetensorLoader
+
+
+@dataclass
+class SafeLoraConfig:
+    """
+    This is the configuration class to store the configuration of SafeLoRA.
+
+    Args:
+        base_model_path (`str`):
+            The path of the base model for obtaining the alignment matrix.
+        aligned_model_path (`str`):
+            The path of the aligned model for obtaining the alignment matrix.
+        peft_model_path (`str`):
+            The path of the LoRA weights and config.
+        select_layers_type (`Literal["threshold", "number"]`):
+            How to select the projected layers. Options: `threshold` or `number`.
+        threshold (`float`):
+            The threshold of cosine similarity for selecting projected layers.
+        num_proj_layers (`int`):
+            The number of projected layers.
+        device (`str`):
+            The device used for SafeLoRA (`cuda` or `cpu`).
+        save_weights (`bool`):
+            Whether to replace and save the original LoRA file with the SafeLoRA weights.
+        local_files_only (`bool`):
+            Whether `snapshot_download` should only look at local files.
+        dtype (`torch.dtype`):
+            Data type for model weights, e.g., `torch.float32` or `torch.bfloat16`.
+    """
+
+    base_model_path: str = field(
+        default="meta-llama/Llama-2-7b-hf",
+        metadata={"help": "The path of the base model for obtaining the alignment matrix."},
+    )
+
+    aligned_model_path: str = field(
+        default="TheBloke/Llama-2-7B-Chat-fp16",
+        metadata={"help": "The path of the aligned model for obtaining the alignment matrix."},
+    )
+
+    peft_model_path: str = field(
+        default="LisaSchunke/llama-2-7b-peft-finetuned-20000-dataset",
+        metadata={"help": "The path of the LoRA weights and config."},
+    )
+
+    select_layers_type: Literal["threshold", "number"] = field(
+        default="number",
+        metadata={"help": "How to select the projected layers. Options: [threshold, number]."},
+    )
+
+    threshold: float = field(
+        default=0.5,
+        metadata={"help": "The threshold of cosine similarity for selecting projected layers."},
+    )
+
+    num_proj_layers: int = field(
+        default=10,
+        metadata={"help": "The number of projected layers."},
+    )
+
+    device: str = field(
+        default="cuda",
+        metadata={"help": "The device used for SafeLoRA (cuda or cpu)."},
+    )
+
+    save_weights: bool = field(
+        default=True,
+        metadata={"help": "Whether to replace and save the original LoRA file with the SafeLoRA weights."},
+    )
+
+    local_files_only: bool = field(
+        default=False,
+        metadata={"help": "Whether snapshot_download should only look at local files."},
+    )
+
+    dtype: torch.dtype = field(
+        default=torch.bfloat16,
+        metadata={
+            "help": "Data type for model weights, e.g., torch.float32 or torch.bfloat16. If your device is CPU, you should use torch.float32."
+        },
+    )
+
+    def __post_init__(self):
+        if self.base_model_path is None:
+            raise ValueError("base_model_path cannot be None.")
+        if self.aligned_model_path is None:
+            raise ValueError("aligned_model_path cannot be None.")
+        if self.peft_model_path is None:
+            raise ValueError("peft_model_path cannot be None.")
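As the `dtype` help text notes, half precision is the default but CPU runs should use `torch.float32`. A minimal sketch of such a configuration, where the PEFT model path is a placeholder:

```python
import torch

from peft.utils.safelora import SafeLoraConfig

# On CPU, bfloat16/float16 matmuls can be slow or unsupported, so use float32.
cpu_config = SafeLoraConfig(
    base_model_path="meta-llama/Llama-2-7b-hf",
    aligned_model_path="TheBloke/Llama-2-7B-Chat-fp16",
    peft_model_path="path/to/your/peft/model",  # placeholder
    device="cpu",
    dtype=torch.float32,
)
```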
+def get_aligned_matrix(base_model_path, aligned_model_path, peft_config, safelora_config):
+    """
+    Get the projection matrices for the layers named in the PEFT config's target_modules. The weights of the base
+    model and the aligned model must have the same dimensions.
+    """
+    sl_align = SafetensorLoader(aligned_model_path, local_files_only=safelora_config.local_files_only)
+    sl_base = SafetensorLoader(base_model_path, local_files_only=safelora_config.local_files_only)
+
+    base_model_parameters = [
+        name for name in sl_base.weight_map.keys() if any(v in name for v in list(peft_config.target_modules))
+    ]
+    align_model_parameters = [
+        name for name in sl_align.weight_map.keys() if any(v in name for v in list(peft_config.target_modules))
+    ]
+
+    safety_vector = []
+    for name_base, name_align in zip(base_model_parameters, align_model_parameters):
+        if sl_base.get_tensor(name_base).shape != sl_align.get_tensor(name_align).shape:
+            raise ValueError(
+                "The dimensions of the base model's weights must be the same as those of the aligned model's weights."
+            )
+        if (sl_base.get_tensor(name_base) == sl_align.get_tensor(name_align)).all():
+            raise ValueError("The weights of the base model and the aligned model must be different.")
+        # The projection matrix is derived from the difference between the base and the aligned weights.
+        vec = sl_base.get_tensor(name_base) - sl_align.get_tensor(name_align)
+        vec = vec.to(safelora_config.dtype).to(safelora_config.device)
+        vec = torch.mm(vec, vec.t()) / torch.norm(vec)
+        safety_vector.append(vec.detach().cpu())
+    return safety_vector
+
+
+def project_weights(safelora_config, peft_weights, v):
+    ori_peft_weights = copy.deepcopy(peft_weights)
+    vars_names_LoRA_A = [name for name in peft_weights.keys() if "lora_A" in name]
+    vars_names_LoRA_B = [name for name in peft_weights.keys() if "lora_B" in name]
+    cos_total = []
+    for idx, (name_A, name_B) in enumerate(zip(vars_names_LoRA_A, vars_names_LoRA_B)):
+        A = ori_peft_weights[name_A]
+        P = v[idx].to(safelora_config.dtype).to(safelora_config.device)
+        W = torch.mm(P, ori_peft_weights[name_B])
+        fW = torch.mm(W, A)
+        ori = torch.mm(ori_peft_weights[name_B], A)
+        # Compare the projected LoRA update with the original one.
+        cos = torch.nn.functional.cosine_similarity(fW.reshape(1, -1), ori.reshape(1, -1))
+        cos_total.append(cos.item())
+
+        # Layers whose updates deviate from the alignment subspace (low cosine similarity) are projected.
+        if cos <= safelora_config.threshold:
+            peft_weights[name_B] = W
+        else:
+            peft_weights[name_B] = ori_peft_weights[name_B]
+
+    return peft_weights, cos_total
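To make the gating in `project_weights` concrete, here is a toy reproduction on random tensors (shapes and the threshold are illustrative only): the LoRA update `B @ A` is compared with its projected counterpart `(P @ B) @ A`, and `B` is replaced only when the cosine similarity falls at or below the threshold:

```python
import torch

torch.manual_seed(0)
d, r, thr = 16, 4, 0.5                       # toy dimensions and threshold
C = torch.randn(d, d)                        # stand-in for W_base - W_aligned
P = C @ C.t() / torch.norm(C)                # projection as built in get_aligned_matrix
B, A = torch.randn(d, r), torch.randn(r, d)  # toy LoRA factors

update = B @ A
projected = (P @ B) @ A
cos = torch.nn.functional.cosine_similarity(
    projected.reshape(1, -1), update.reshape(1, -1)
)
# Low similarity means the update deviates from the alignment subspace,
# so the projected B is kept; otherwise the original B is left untouched.
B_new = P @ B if cos.item() <= thr else B
```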
+def apply_safelora(safelora_config: SafeLoraConfig):
+    """
+    Implements Safe LoRA: The Silver Lining of Reducing Safety Risks when Fine-tuning Large Language Models:
+    https://arxiv.org/abs/2405.16833
+
+    After fine-tuning large language models (LLMs) with LoRA, the alignment of the resulting models may degrade.
+    Applying `apply_safelora()` helps preserve the alignment of the final model.
+
+    Note that the model weights of the aligned model and the base model must be of the same size.
+
+    Args:
+        safelora_config (`SafeLoraConfig`):
+            The configuration for SafeLoRA.
+
+    Returns:
+        `dict[str, torch.Tensor]`: The projected (SafeLoRA) adapter weights.
+
+    Example:
+
+        ```py
+        >>> from peft.utils.safelora import SafeLoraConfig, apply_safelora
+
+        >>> config = SafeLoraConfig(
+        ...     base_model_path="meta-llama/Llama-2-7b-hf",
+        ...     aligned_model_path="TheBloke/Llama-2-7B-Chat-fp16",
+        ...     peft_model_path="LisaSchunke/llama-2-7b-peft-finetuned-20000-dataset",
+        ...     device="cuda",
+        ...     select_layers_type="threshold",
+        ...     save_weights=True,
+        ... )
+        >>> final_lora_weight = apply_safelora(config)
+        ```
+    """
+    peft_config = PeftConfig.from_pretrained(safelora_config.peft_model_path)
+
+    projected_matrix = get_aligned_matrix(
+        safelora_config.base_model_path, safelora_config.aligned_model_path, peft_config, safelora_config
+    )
+
+    adapter_path = os.path.join(safelora_config.peft_model_path, "adapter_model.safetensors")
+    with safe_open(adapter_path, framework="pt", device=safelora_config.device) as f:
+        peft_weights = {name: f.get_tensor(name).to(safelora_config.dtype) for name in f.keys()}
+
+    if safelora_config.select_layers_type == "threshold":
+        final_weights, _ = project_weights(safelora_config, peft_weights, projected_matrix)
+    elif safelora_config.select_layers_type == "number":
+        # First pass on a copy: only the cosine similarities are needed here, and
+        # project_weights modifies the weights it receives in place.
+        _, cos = project_weights(safelora_config, copy.deepcopy(peft_weights), projected_matrix)
+        # Set the threshold to the num_proj_layers-th smallest cosine similarity so that
+        # exactly num_proj_layers layers fall at or below it in the second pass.
+        thrs = torch.sort(torch.Tensor(cos))[0][: safelora_config.num_proj_layers][-1].item()
+        safelora_config.threshold = thrs
+        final_weights, _ = project_weights(safelora_config, peft_weights, projected_matrix)
+    else:
+        raise ValueError(f"Unknown select_layers_type: {safelora_config.select_layers_type}")
+
+    if safelora_config.save_weights:
+        save_file(final_weights, adapter_path)
+
+    return final_weights
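Once `apply_safelora` has written the projected weights back to the adapter folder (`save_weights=True`), the result loads like any other PEFT adapter. An end-to-end inference sketch, where the adapter path and the prompt are placeholders and `accelerate` is assumed to be installed for `device_map="auto"`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import PeftModel

base_model_path = "TheBloke/Llama-2-7B-Chat-fp16"
peft_path = "path/to/your/peft/model"  # placeholder: the folder apply_safelora wrote to

tokenizer = AutoTokenizer.from_pretrained(base_model_path)
model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(model, peft_path)

inputs = tokenizer("Summarize: the meeting covered budget and hiring.", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```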