add loralinear #10385

Merged
merged 15 commits on Apr 23, 2025
6 changes: 6 additions & 0 deletions paddlenlp/peft/lora/lora_layers.py
@@ -65,6 +65,8 @@
pissa: bool = False,
lora_use_mixer: bool = False,
use_mora: bool = False,
mp_moe: bool = False,
is_distributed: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -143,6 +145,10 @@
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
self.disable_lora = False
if mp_moe or is_distributed:
for p in self.parameters():
p.is_distributed = is_distributed
Contributor Author

Used for EP (expert parallelism): is_distributed marks that these parameters should not be synchronized at the start of training, and mp_moe is used for uc.

p.mp_moe = mp_moe

def pissa_init(self, rank):
weight = self.weight
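The reviewer's note above explains the intent of the two new flags: for expert parallelism (EP), parameters tagged with is_distributed must not be synchronized at the start of training, and mp_moe marks them for the MoE-specific handling ("uc"). The snippet below is a minimal sketch, not part of the PR, of how the new keywords could be exercised once this patch is applied; the shapes and LoRA hyperparameters are illustrative, and only mp_moe and is_distributed are new.

# Minimal sketch (assumes paddlenlp with this patch installed; values are illustrative).
from paddlenlp.peft.lora.lora_layers import LoRALinear

layer = LoRALinear(
    in_features=64,
    out_features=64,
    r=8,
    lora_alpha=16,
    mp_moe=True,          # new in this PR: mark parameters for MoE ("uc") handling
    is_distributed=True,  # new in this PR: do not broadcast these parameters at training start
)

for p in layer.parameters():
    # Every parameter owned by the layer now carries the markers that the
    # distributed training utilities check before synchronizing weights.
    print(p.name, getattr(p, "is_distributed", False), getattr(p, "mp_moe", False))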
163 changes: 33 additions & 130 deletions paddlenlp/peft/lora/lora_model.py
@@ -20,7 +20,7 @@
import tempfile
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Union
from typing import Dict, Union

import aistudio_sdk
import numpy as np
@@ -97,38 +97,29 @@
LoRALinear = lora_layers["LoRALinear"]
RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"]
RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"]

from ...quantization.quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
RowParallelQuantizationLinear,
)
from .lora_quantization_layers import (
ColumnParallelQuantizationLoRALinear,
QuantizationLoRALinear,
RowParallelQuantizationLoRALinear,
)

AVAILABLE_LAYERS = [
ColumnParallelLoRALinear,
ColumnSequenceParallelLoRALinear,
LoRAConv2D,
LoRALinear,
RowParallelLoRALinear,
RowSequenceParallelLoRALinear,
ColumnParallelQuantizationLoRALinear,
QuantizationLoRALinear,
RowParallelQuantizationLoRALinear,
]
try:
from ...quantization.quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
RowParallelQuantizationLinear,
)
from .lora_quantization_layers import (
ColumnParallelQuantizationLoRALinear,
QuantizationLoRALinear,
RowParallelQuantizationLoRALinear,
)

AVAILABLE_LAYERS += [
ColumnParallelQuantizationLoRALinear,
QuantizationLoRALinear,
RowParallelQuantizationLoRALinear,
]
except:
QuantizationLinear = None
ColumnParallelQuantizationLinear = None
RowParallelQuantizationLinear = None
QuantizationLoRALinear = None
ColumnParallelQuantizationLoRALinear = None
RowParallelQuantizationLoRALinear = None


class LoRAModel(nn.Layer):
@@ -426,11 +417,6 @@

if self.is_pipelinemodel:
self.model._single_to_pp_mapping = None
if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
merge_tensor_parallel = False
logger.warning(
"Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
)
if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
merge_tensor_parallel = False
logger.warning(
@@ -479,7 +465,7 @@
model_config_to_save.tensor_parallel_degree = -1
model_config_to_save.save_pretrained(save_directory)

def _find_and_replace_module(self, model, module_name, lora_config, enable_lora):
def _find_and_replace_module(self, model, module_name, lora_config):
parent_module = model
attribute_chain = module_name.split(".")
for name in attribute_chain[:-1]:
@@ -500,14 +486,10 @@
use_quick_lora=lora_config.use_quick_lora,
lora_use_mixer=lora_config.lora_use_mixer,
use_mora=lora_config.use_mora,
mp_moe=getattr(module.weight, "mp_moe", False),
is_distributed=getattr(module.weight, "is_distributed", False),
)
# Hack for mp group moe, need to find a better solution.
if getattr(module.weight, "mp_moe", False):
lora_module.lora_A.mp_moe = True
lora_module.lora_B.mp_moe = True
lora_module.lora_A.is_distributed = True
lora_module.lora_B.is_distributed = True
if isinstance(module, nn.Conv2D):
elif isinstance(module, nn.Conv2D):
lora_module = LoRAConv2D(
in_channels=module._in_channels,
out_channels=module._out_channels,
@@ -621,68 +603,20 @@
self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
elif QuantizationLinear is not None and isinstance(module, QuantizationLinear):
lora_module = QuantizationLoRALinear(
in_features=module.in_features,
out_features=module.out_features,
quant_algo=module.quant_algo,
dtype=module._dtype,
bias_attr=False if module.bias is None else None,
block_size=module.block_size,
double_quant_block_size=module.double_quant_block_size,
double_quant=module.double_quant,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
)
self.quantized = True
elif ColumnParallelQuantizationLinear is not None and isinstance(module, ColumnParallelQuantizationLinear):
lora_module = ColumnParallelQuantizationLoRALinear(
in_features=module.in_features,
out_features=module.out_features,
quant_algo=module.quant_algo,
dtype=module._dtype,
bias_attr=False if module.bias is None else None,
gather_output=module.gather_output,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
lora_A_weight_attr=paddle.ParamAttr(
initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
),
)
self.quantized = True
elif RowParallelQuantizationLinear is not None and isinstance(module, RowParallelQuantizationLinear):
lora_module = RowParallelQuantizationLoRALinear(
in_features=module.in_features,
out_features=module.out_features,
quant_algo=module.quant_algo,
dtype=module._dtype,
bias_attr=False if module.bias is None else None,
input_is_parallel=module.input_is_parallel,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
)
self.quantized = True
elif isinstance(module, QuantizationLinear):
lora_module = QuantizationLoRALinear(module, lora_config)

elif isinstance(module, ColumnParallelQuantizationLinear):
lora_module = ColumnParallelQuantizationLoRALinear(module, lora_config)

elif isinstance(module, RowParallelQuantizationLinear):
lora_module = RowParallelQuantizationLoRALinear(module, lora_config)

if lora_module is None:
raise ValueError(
f"LoRA strategy only supports paddle.nn.Linear or paddle.distributed.fleet.meta_parallel.ColumnParallelLinear or paddlenlp.transformers.sequence_utils. {module}({module_name} {type(module).__name__}) is not supported。"
)
if getattr(lora_module, "quant_weight", None) is not None:
lora_module.quant_weight = module.quant_weight
if getattr(lora_module, "quant_scale", None) is not None:
lora_module.quant_scale = module.quant_scale
if getattr(lora_module, "qquant_scale", None) is not None:
lora_module.qquant_scale = module.qquant_scale
if getattr(lora_module, "double_quant_scale", None) is not None:
lora_module.double_quant_scale = module.double_quant_scale
if getattr(lora_module, "quant_sacle_offset", None) is not None:
lora_module.quant_sacle_offset = module.quant_sacle_offset
else:
if getattr(lora_module, "weight", None) is not None:
lora_module.weight = module.weight
if module.bias is not None:
lora_module.bias = module.bias
if module.bias is not None:
lora_module.bias = module.bias
setattr(parent_module, attribute_chain[-1], lora_module)

def _find_and_restore_module(self, module_name):
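One design point in the hunk above is worth calling out: the quantization branches no longer copy quant_algo, block_size, double_quant, and the other settings out of the original layer field by field; the refactored LoRA wrappers are constructed directly from the existing quantized module plus the lora_config. The snippet below is a generic illustration of that pattern with hypothetical names, not the actual QuantizationLoRALinear implementation.

# Generic illustration of the "wrap the existing layer" refactor; names are hypothetical.
class QuantLoRAWrapperSketch:
    def __init__(self, base_layer, lora_config):
        # All quantization settings remain on the wrapped layer, so every call
        # site collapses to a single constructor call: Wrapper(module, lora_config).
        self.base_layer = base_layer
        self.r = lora_config.r
        self.lora_alpha = lora_config.lora_alpha
        self.lora_dropout = lora_config.lora_dropout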
Expand Down Expand Up @@ -768,45 +702,14 @@

if lora_config.target_modules is None:
return model
elif isinstance(lora_config.target_modules, str):
target_modules = [lora_config.target_modules]
if lora_config.enable_lora_list is None or (
isinstance(lora_config.enable_lora_list, List)
and all(isinstance(item, bool) for item in lora_config.enable_lora_list)
):
enable_lora_list = [lora_config.enable_lora_list]
else:
raise TypeError(
f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `str`, `enable_lora_list` must be `None` or `List[bool]`"
)
else:
target_modules = lora_config.target_modules
if lora_config.enable_lora_list is None:
enable_lora_list = [None for _ in range(len(target_modules))]
elif isinstance(lora_config.enable_lora_list, List):
enable_lora_list = lora_config.enable_lora_list
if len(enable_lora_list) != len(target_modules):
raise TypeError(
f"Invalid lora_config.enable_lora_list value: {lora_config.enable_lora_list}. Since lora_config.target_modules is `List[str]`, `enable_lora_list` should have the same length as `target_modules`"
)
for enable_lora in enable_lora_list:
if not (
enable_lora is None
or (isinstance(enable_lora, List) and all(isinstance(item, bool) for item in enable_lora))
):
raise TypeError(
f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `List[str]`, `enable_lora_list` must be `None` or `List[Optional[List[bool]]]`"
)
else:
raise TypeError(
f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `List[str]`, `enable_lora_list` must be `None` or `List[Optional[List[bool]]]`"
)
if isinstance(lora_config.target_modules, str):
lora_config.target_modules = [lora_config.target_modules]

for target_module, enable_lora in zip(target_modules, enable_lora_list):
for target_module in lora_config.target_modules:
for i in model.named_sublayers():
module_name = i[0]
if re.fullmatch(target_module, module_name):
self._find_and_replace_module(model, module_name, lora_config, enable_lora)
self._find_and_replace_module(model, module_name, lora_config)
return model

def restore_original_model(self):
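With the enable_lora_list branches removed, LoRA injection reduces to normalizing target_modules to a list of regex patterns and matching them against sublayer names with re.fullmatch. The standalone sketch below shows just that matching step; the model and the example patterns are placeholders, not values taken from this PR.

# Standalone sketch of the simplified matching flow (patterns are hypothetical).
import re

def matched_module_names(model, target_modules):
    """Return the sublayer names that _find_and_replace_module would rewrite."""
    if isinstance(target_modules, str):
        target_modules = [target_modules]
    names = []
    for name, _ in model.named_sublayers():
        if any(re.fullmatch(pattern, name) for pattern in target_modules):
            names.append(name)
    return names

# Example: matched_module_names(model, [".*q_proj", ".*v_proj"])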