diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index e20e64df3e..470305930a 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -155,7 +155,8 @@ def from_pretrained(cls, model, model_id, **kwargs):
                     f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
                 )
 
-        adapters_weights = torch.load(filename)
+        adapters_weights = torch.load(
+            filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
         # load the weights into the model
         model = set_peft_model_state_dict(model, adapters_weights)
         if getattr(model, "hf_device_map", None) is not None:
@@ -266,7 +267,12 @@ def print_trainable_parameters(self):
         trainable_params = 0
         all_param = 0
         for _, param in self.named_parameters():
-            all_param += param.numel()
+            num_params = param.numel()
+            # if using DS Zero 3 and the weights are initialized empty
+            if num_params == 0 and hasattr(param, "ds_numel"):
+                num_params = param.ds_numel
+
+            all_param += num_params
             if param.requires_grad:
                 trainable_params += param.numel()
         print(
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 3f4d2d5cd1..132b033484 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -30,7 +30,9 @@ def bloom_model_postprocess_past_key_value(past_key_values):
     return tuple(zip(keys, values))
 
 
-def prepare_model_for_int8_training(model, output_embedding_layer_name="lm_head"):
+def prepare_model_for_int8_training(
+    model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"]
+):
     r"""
     This method wrapps the entire protocol for preparing a model before running a training. This includes:
         1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
@@ -48,20 +50,20 @@ def prepare_model_for_int8_training(model, output_embedding_layer_name="lm_head"
 
         if loaded_in_8bit:
             # cast layer norm in fp32 for stability for 8bit models
-            if param.ndim == 1 and "layer_norm" in name:
+            if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
                 param.data = param.data.to(torch.float32)
 
-    # For backward compatibility
-    if hasattr(model, "enable_input_require_grads"):
-        model.enable_input_require_grads()
-    else:
+    if loaded_in_8bit and use_gradient_checkpointing:
+        # For backward compatibility
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
 
-        def make_inputs_require_grad(module, input, output):
-            output.requires_grad_(True)
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
 
-        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
-    if loaded_in_8bit:
         # enable gradient checkpointing for memory efficiency
         model.gradient_checkpointing_enable()
 
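A minimal usage sketch for the updated signature of prepare_model_for_int8_training, assuming PEFT exposes the helper at the package top level and that an 8-bit model is loaded through transformers with bitsandbytes; the checkpoint name and the extra "layernorm" entry below are illustrative assumptions, not taken from this diff:

from transformers import AutoModelForCausalLM
from peft import prepare_model_for_int8_training

# Load a model in 8-bit so that `is_loaded_in_8bit` is set on it
# (illustrative checkpoint, requires bitsandbytes).
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    load_in_8bit=True,
    device_map="auto",
)

# New keyword arguments introduced by the diff: gradient checkpointing can be
# toggled, and the name substrings used to pick layer norms for the fp32 cast
# are configurable instead of being hard-coded to "layer_norm".
model = prepare_model_for_int8_training(
    model,
    use_gradient_checkpointing=True,
    layer_norm_names=["layer_norm", "layernorm"],
)

Passing use_gradient_checkpointing=False would skip both the input-require-grads hook and gradient_checkpointing_enable(), which the old code always applied to 8-bit models.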