FIX: Error with OLoRA init when using bnb (huggingface#2011)
BenjaminBossan authored Sep 3, 2024
1 parent 01275b4 commit 37b9c5c
Showing 3 changed files with 76 additions and 7 deletions.
36 changes: 29 additions & 7 deletions src/peft/tuners/lora/layer.py
@@ -25,7 +25,7 @@
 from transformers.pytorch_utils import Conv1D
 
 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
-from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
+from peft.utils.integrations import dequantize_module_weight, gather_params_ctx, get_bnb_param_type
 from peft.utils.other import transpose
 
 from .config import LoraConfig
@@ -167,11 +167,16 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights):
             nn.init.normal_(self.lora_embedding_B[adapter_name])
 
     def olora_init(self, adapter_name):
-        dtype = self.get_base_layer().weight.dtype
-        if dtype in [torch.int8, torch.uint8]:
-            weight_tensor = dequantize_module_weight(self.get_base_layer())
+        base_layer = self.get_base_layer()
+        orig_weight = base_layer.weight
+        bnb_param_type = get_bnb_param_type(orig_weight)
+        dtype = orig_weight.dtype
+
+        if bnb_param_type:
+            # check without importing bitsandbytes and robust to bnb_4bit_quant_storage=float*
+            weight_tensor = dequantize_module_weight(base_layer)
         elif dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            weight_tensor = self.get_base_layer().weight
+            weight_tensor = orig_weight
         else:
             raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.")
 
@@ -186,8 +191,25 @@ def olora_init(self, adapter_name):
         self.lora_B[adapter_name].weight.data = Qr.contiguous()
 
         weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight
-        weight_tensor = weight_tensor.to(dtype)
-        self.get_base_layer().weight.data = weight_tensor
+        if bnb_param_type == "4bit":
+            weight_tensor = orig_weight.__class__(
+                weight_tensor,
+                quant_type=orig_weight.quant_type,
+                quant_storage=orig_weight.quant_storage,
+                compress_statistics=orig_weight.compress_statistics,
+                module=orig_weight.module,
+            ).to(orig_weight.device)
+            base_layer.weight = weight_tensor
+        elif bnb_param_type == "8bit":
+            weight_tensor = orig_weight.__class__(
+                weight_tensor,
+                requires_grad=orig_weight.requires_grad,
+                has_fp16_weights=orig_weight.has_fp16_weights,
+            ).to(orig_weight.device)
+            base_layer.weight = weight_tensor
+        else:
+            weight_tensor = weight_tensor.to(dtype)
+            base_layer.weight.data = weight_tensor
 
     def pissa_init(self, adapter_name, init_lora_weights):
         weight = self.get_base_layer().weight
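The core of the fix: the old path keyed on the weight's dtype, which is not reliable for bitsandbytes parameters (with bnb_4bit_quant_storage=float16 a Params4bit does not report an integer dtype), and it wrote the updated tensor straight back into the parameter's .data instead of re-quantizing it. The new code detects the bnb parameter type explicitly, dequantizes, applies the OLoRA update, and re-wraps the result in the original quantized parameter class. Below is a minimal standalone sketch of that round trip, not the PEFT implementation itself; `requantize_after_update` and `delta` are hypothetical names, while `dequantize_module_weight` and `get_bnb_param_type` are the real helpers from `peft.utils.integrations` touched by this commit.

```python
# Sketch only (assumed names: requantize_after_update, delta) of the
# dequantize -> update -> re-quantize round trip the fix performs for bnb layers.
import torch

from peft.utils.integrations import dequantize_module_weight, get_bnb_param_type


def requantize_after_update(base_layer: torch.nn.Module, delta: torch.Tensor) -> None:
    """Subtract `delta` from the base weight, restoring the original (possibly quantized) parameter type."""
    orig_weight = base_layer.weight
    bnb_param_type = get_bnb_param_type(orig_weight)  # "4bit", "8bit", or False

    # Work on a dequantized float copy so the update happens in full precision.
    weight = dequantize_module_weight(base_layer) if bnb_param_type else orig_weight
    weight = weight.to(torch.float32)
    weight.data -= delta

    if bnb_param_type == "4bit":
        # Rebuild a Params4bit with the same quantization settings as the original weight.
        base_layer.weight = orig_weight.__class__(
            weight,
            quant_type=orig_weight.quant_type,
            quant_storage=orig_weight.quant_storage,
            compress_statistics=orig_weight.compress_statistics,
            module=orig_weight.module,
        ).to(orig_weight.device)
    elif bnb_param_type == "8bit":
        # Rebuild an Int8Params; moving it to the original device re-quantizes the data.
        base_layer.weight = orig_weight.__class__(
            weight,
            requires_grad=orig_weight.requires_grad,
            has_fp16_weights=orig_weight.has_fp16_weights,
        ).to(orig_weight.device)
    else:
        # Plain float weight: cast back to the original dtype and update in place.
        base_layer.weight.data = weight.to(orig_weight.dtype)
```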
12 changes: 12 additions & 0 deletions src/peft/utils/integrations.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from contextlib import contextmanager
+from typing import Literal
 
 import packaging.version
 import torch
@@ -104,3 +107,12 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     if is_cpu:
         dequantized = dequantized.to(device)
     return dequantized
+
+
+def get_bnb_param_type(param: torch.nn.Parameter) -> Literal[False, "4bit", "8bit"]:
+    """Returns '4bit' or '8bit' if bitsandbytes parameter, else False"""
+    if param.__class__.__name__ == "Params4bit":
+        return "4bit"
+    if param.__class__.__name__ == "Int8Params":
+        return "8bit"
+    return False
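The new helper identifies the bitsandbytes parameter type by class name, so PEFT does not need to import bitsandbytes for the check and the result does not depend on which dtype the 4-bit storage reports. A brief usage sketch; it uses a plain nn.Linear so it runs without bitsandbytes installed:

```python
import torch

from peft.utils.integrations import get_bnb_param_type

# With a plain nn.Linear the helper returns False.
layer = torch.nn.Linear(8, 8)
bnb_param_type = get_bnb_param_type(layer.weight)  # False here

# With a bitsandbytes layer (e.g. bnb.nn.Linear4bit / Linear8bitLt) the same call
# returns "4bit" or "8bit", matched purely on the parameter's class name.
if bnb_param_type == "4bit":
    ...  # weight is a bitsandbytes Params4bit
elif bnb_param_type == "8bit":
    ...  # weight is a bitsandbytes Int8Params
else:
    assert bnb_param_type is False  # regular torch parameter
```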
35 changes: 35 additions & 0 deletions tests/test_gpu_examples.py
@@ -1786,6 +1786,41 @@ def test_bloomz_olora_8bit(self, device, tmp_path):
         # Same test as test_bloomz_olora_4bit but with 8 bits.
         self.get_errors(bits=8, device=device, tmp_path=tmp_path)
 
+    @pytest.mark.parametrize("bits", [4, 8])
+    def test_olora_with_quantized_model(self, bits):
+        import bitsandbytes as bnb
+
+        # issue 1999
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        if bits == 4:
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_quant_storage=torch.float16,
+                bnb_4bit_use_double_quant=True,
+            )
+        elif bits == 8:
+            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+        else:
+            raise ValueError("bits must be 4 or 8")
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
+        model = prepare_model_for_kbit_training(model)
+        config = LoraConfig(init_lora_weights="olora")
+        model = get_peft_model(model, config)
+
+        # check that the correct type is used for the weights
+        base_layer = model.base_model.model.model.decoder.layers[0].self_attn.v_proj.base_layer.weight
+        if bits == 4:
+            assert isinstance(base_layer, bnb.nn.modules.Params4bit)
+        else:
+            assert isinstance(base_layer, bnb.nn.modules.Int8Params)
+
+        inputs = torch.arange(10).unsqueeze(0).to(model.device)
+        logits = model(inputs).logits  # does not raise
+        assert torch.isfinite(logits).all()
+
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
 class TestLoftQ:
