Weight LoRA #2406

Open
wants to merge 2 commits into base: main
31 changes: 31 additions & 0 deletions docs/source/package_reference/weightlora.md
@@ -0,0 +1,31 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# WeightLoRA

Weight LoRA is a simple yet important PEFT method that adds a scalar weight $w_i$ to each LoRA adapter (where $i$ is the adapter index). This makes it possible to perform, in addition to the classical optimization over all LoRA matrices $A_1, B_1, \ldots, A_n, B_n$, an alternative optimization over the weight vector $w := (w_1, \ldots, w_n)^T \in \mathbb{R}^n$ under a wide variety of constraints. In our research paper, we consider two approaches: 1) the vector $w$ must lie in the simplex $\Delta_{n-1}$, and 2) the vector $w$ may have only $K$ non-zero coordinates. Both approaches address the problem of finding the most important LoRA adapters in the model and concentrating training on them while disabling the rest.
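As a rough sketch of the idea (written from the description above rather than copied from the paper, and assuming the usual LoRA scaling $\alpha/r$), the $i$-th adapted layer computes

$$
h_i = W_i x + w_i \, \frac{\alpha}{r} \, B_i A_i x,
$$

while the weight vector $w = (w_1, \ldots, w_n)^T$ is constrained either to lie in the simplex $\Delta_{n-1} = \{\, w \in \mathbb{R}^n : w_i \ge 0,\ \textstyle\sum_{i=1}^{n} w_i = 1 \,\}$ or to have only $K$ non-zero coordinates, i.e. $\|w\|_0 \le K$. Setting $w_i = 0$ disables the $i$-th adapter entirely, so optimizing over $w$ amounts to selecting which LoRA heads are worth training.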

The abstract from the paper is:

The widespread utilization of language models in modern applications is inconceivable without Parameter-Efficient Fine-Tuning techniques, such as low-rank adaptation (LoRA), which adds trainable adapters to selected layers. Although LoRA may obtain accurate solutions, it requires significant memory to train large models and intuition on which layers to add adapters. In this paper, we propose a novel method, WeightLoRA, which overcomes this issue by adaptive selection of the most critical LoRA heads throughout the optimization process. As a result, we can significantly reduce the number of trainable parameters while maintaining the capability to obtain consistent or even superior metric values. Finally, we conduct experiments for the series of competitive benchmarks and DeBERTa and BART models, comparing our approach with the most popular LoRA modifications. The experimental results demonstrate the efficacy of WeightLoRA and the superior performance of WeightLoRA+ in comparison to the baselines in nearly all cases.

## WeightLoraConfig

[[autodoc]] tuners.weight_lora.config.WeightLoraConfig

## WeightLoraModel

[[autodoc]] tuners.weight_lora.model.WeightLoraModel
4 changes: 4 additions & 0 deletions src/peft/__init__.py
@@ -95,6 +95,8 @@
VeraModel,
XLoraConfig,
XLoraModel,
WeightLoraConfig,
WeightLoraModel,
get_eva_state_dict,
initialize_lora_eva_weights,
)
@@ -188,6 +190,8 @@
"VeraModel",
"XLoraConfig",
"XLoraModel",
"WeightLoraConfig",
"WeightLoraModel",
"bloom_model_postprocess_past_key_value",
"cast_mixed_precision_params",
"get_eva_state_dict",
3 changes: 3 additions & 0 deletions src/peft/tuners/__init__.py
@@ -43,6 +43,7 @@
from .vblora import VBLoRAConfig, VBLoRAModel
from .vera import VeraConfig, VeraModel
from .xlora import XLoraConfig, XLoraModel
from .weight_lora import WeightLoraConfig, WeightLoraModel


__all__ = [
@@ -97,6 +98,8 @@
"VeraModel",
"XLoraConfig",
"XLoraModel",
"WeightLoraConfig",
"WeightLoraModel",
"get_eva_state_dict",
"initialize_lora_eva_weights",
]
23 changes: 23 additions & 0 deletions src/peft/tuners/weight_lora/__init__.py
@@ -0,0 +1,23 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import WeightLoraConfig
from .layer import Linear, WeightLoraLayer
from .model import WeightLoraModel

__all__ = ["Linear", "WeightLoraConfig", "WeightLoraLayer", "WeightLoraModel"]

register_peft_method(name="weightlora", config_cls=WeightLoraConfig, model_cls=WeightLoraModel, prefix="weight_lora_", is_mixed_compatible=True)
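The actual layer implementation lives in `layer.py`, which is imported above but not included in this diff. As an illustration only (the class below is hypothetical and is not the code from this PR), a weighted LoRA linear layer along the lines described in the documentation might look like this:

```python
import torch
import torch.nn as nn


class WeightedLoraLinearSketch(nn.Module):
    """Illustrative sketch only -- not the `layer.py` implementation from this PR.

    Adds a scalar weight ``w`` in front of the usual LoRA update so that, besides
    training A and B, one can also optimize which adapters matter (w -> 0 disables a head).
    """

    def __init__(self, base_layer: nn.Linear, r: int = 8, lora_alpha: int = 8):
        super().__init__()
        self.base_layer = base_layer  # frozen pretrained projection W
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # standard LoRA init: the update starts at zero
        self.scaling = lora_alpha / r
        self.lora_w = nn.Parameter(torch.ones(1))  # the per-adapter weight w_i

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base output plus the weighted low-rank update.
        return self.base_layer(x) + self.lora_w * self.scaling * self.lora_B(self.lora_A(x))
```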
110 changes: 110 additions & 0 deletions src/peft/tuners/weight_lora/config.py
@@ -0,0 +1,110 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import List, Optional, Union

from peft.tuners.lora import LoraConfig
from peft.utils import PeftType


@dataclass
class WeightLoraConfig(LoraConfig):
"""
Configuration class of [`WeightLoraModel`].

Args:
r (`int`):
Lora rank.
lora_alpha (`int`):
The alpha parameter for Lora scaling.
rank_dropout (`float`):
The dropout probability for rank dimension during training.
module_dropout (`float`):
The dropout probability for disabling Lora modules during training.
target_modules (`Optional[Union[List[str], str]]`):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
strings, either an exact match will be performed or it is checked if the name of the module ends with any
of the passed strings. If this is specified as 'all-linear', then all linear modules are chosen,
excluding the output layer. If this is not specified, modules will be chosen according to the model
architecture. If the architecture is not known, an error will be raised -- in this case, you should specify
the target modules manually.
init_weights (`bool`):
Whether to perform initialization of adapter weights. This defaults to `True`; passing `False` is
discouraged.
layers_to_transform (`Union[List[int], int]`):
The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
that are specified in this list. If a single integer is passed, it will apply the transformations on the
layer at this index.
layers_pattern (`str`):
The layer pattern name, used only if `layers_to_transform` is different from `None`.
rank_pattern (`dict`):
The mapping from layer names or regex expressions to ranks which are different from the default rank
specified by `r`.
alpha_pattern (`dict`):
The mapping from layer names or regex expressions to alphas which are different from the default alpha
specified by `lora_alpha`.
modules_to_save (`Optional[List[str]]`):
List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
"""

r: int = field(default=8, metadata={"help": "Lora rank"})
lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
rank_dropout: float = field(
default=0.0, metadata={"help": "The dropout probability for rank dimension during training"}
)
module_dropout: float = field(
default=0.0, metadata={"help": "The dropout probability for disabling Lora modules during training"}
)
target_modules: Optional[Union[List[str], str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to replace with Lora."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer."
},
)
init_weights: bool = field(
default=True,
metadata={
"help": (
"Whether to initialize the weights of the Lora layers with their default initialization. Don't change "
"this setting, except if you know exactly what you're doing."
),
},
)
layers_to_transform: Optional[Union[List[int], int]] = field(
default=None,
metadata={
"help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
},
)
layers_pattern: Optional[str] = field(
default=None,
metadata={
"help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern."
},
)
modules_to_save: Optional[List[str]] = field(
default=None,
metadata={
"help": "List of modules apart from Lora layers to be set as trainable and saved in the final checkpoint. "
"For example, in Sequence Classification or Token Classification tasks, "
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
},
)

def __post_init__(self):
self.peft_type = PeftType.WEIGHTLORA
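For context, here is a hypothetical end-to-end usage of this config, assuming the standard PEFT `get_peft_model` workflow; the DeBERTa checkpoint and target module names below are illustrative assumptions, and any WeightLoRA-specific options not visible in this diff are omitted:

```python
from peft import WeightLoraConfig, get_peft_model, TaskType
from transformers import AutoModelForSequenceClassification

# Hypothetical usage sketch -- only fields visible in this diff are set explicitly.
base_model = AutoModelForSequenceClassification.from_pretrained("microsoft/deberta-v3-base")

config = WeightLoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["query_proj", "value_proj"],  # assumed attention projections for DeBERTa-v3
    task_type=TaskType.SEQ_CLS,  # inherited from LoraConfig
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()  # LoRA matrices plus the per-adapter weights should be trainable
```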