model.gradient_checkpointing_enable() makes loss.requires_grad be False #35826
Comments
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi 👋, just wanted to follow up on this issue. I understand you might be busy, but I'm curious if there's any update, or if there's additional information I can provide to help resolve it? 🙏 Thanks for your time!
Thanks for the report! I'll have a look soon, but could you try to see whether the problem comes from PEFT? You can use a smaller model for debugging if you don't have enough RAM.
Thanks for your reply! 😄

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
import copy
from types import MethodType
from functools import partial
import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from transformers import PreTrainedModel


def main():
    train_data = {"input": "input test", "output": "output test"}
    model_name = "/workspace/model/CodeLlama-13b-Instruct-hf"
    output_dir = "./test_debug"

    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # set the pad token of the model's configuration
    model.config.pad_token_id = model.config.eos_token_id

    # if not getattr(model, "supports_gradient_checkpointing", False):
    #     print("Current model does not support gradient checkpointing.")
    # else:
    #     # use_reentrant=False might increase VRAM usage (not yet empirically verified)
    #     # According to: https://github.com/huggingface/transformers/issues/28339
    #     model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
    #     model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
    #     setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
    #     print("Gradient checkpointing enabled.")

    input_ids = tokenizer.encode(train_data["input"])
    output_ids = tokenizer.encode(train_data["output"])
    model_inputs_output = input_ids + output_ids + [tokenizer.eos_token_id]
    model_inputs_output = torch.tensor(model_inputs_output, dtype=torch.int64)
    labels = copy.deepcopy(model_inputs_output)
    labels[: len(input_ids)] = -1  # mark prompt tokens; they are set to -100 below
    example_mask = model_inputs_output.ge(0)
    label_mask = labels.ge(0)
    model_inputs_output[~example_mask] = 0
    labels[~label_mask] = -100

    train_dataset = {
        "input_ids": model_inputs_output.unsqueeze(0).to("cuda"),
        "attention_mask": example_mask.unsqueeze(0).to("cuda"),
        "labels": labels.unsqueeze(0).to("cuda"),
    }

    # lora_config = LoraConfig(
    #     r=8,
    #     lora_alpha=16,
    #     target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "up_proj", "k_proj", "down_proj"],  # same as LLaMA-Factory
    #     lora_dropout=0.05,
    #     task_type=TaskType.CAUSAL_LM,
    # )
    # model = get_peft_model(model, lora_config)

    model.gradient_checkpointing_enable()
    model.train()
    # model.print_trainable_parameters()
    model.to("cuda")

    output = model(**train_dataset)
    loss = output["loss"]
    print(f"loss: {loss.requires_grad}")


if __name__ == "__main__":
    main()
```

Under the current circumstances: when not using PEFT, both the built-in model.gradient_checkpointing_enable() method and the custom _gradient_checkpointing_enable function yield loss.requires_grad == True. However, when PEFT is used, only the custom _gradient_checkpointing_enable function (sketched below) does so; model.gradient_checkpointing_enable() leaves loss.requires_grad == False. Is this a usage problem on my side, or a bug caused by a code conflict?
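For context, here is a rough sketch of what a LLaMA-Factory-style _gradient_checkpointing_enable could look like. This is an assumption about the approach, not the exact LLaMA-Factory source, and it relies on the private transformers hook _set_gradient_checkpointing, so treat it as illustrative only. The key idea is to force floating-point inputs of checkpointed modules that contain trainable parameters to require grad:

```python
# Rough sketch (assumed, not the exact LLaMA-Factory code) of a custom
# _gradient_checkpointing_enable. It wraps torch.utils.checkpoint.checkpoint so that
# floating-point inputs to checkpointed modules with trainable parameters are forced
# to require grad, which keeps loss.requires_grad == True even when the embedding
# weights are frozen (e.g. under LoRA).
from functools import partial
from types import MethodType
from typing import Optional

import torch
from torch.utils.checkpoint import checkpoint


def _gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: Optional[dict] = None):
    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}
    checkpoint_fn = partial(checkpoint, **gradient_checkpointing_kwargs)

    def custom_checkpoint_fn(func, *args, **kwargs):
        # `func` is the bound __call__ of a decoder layer; if that layer has any
        # trainable parameter, make its floating-point tensor inputs require grad.
        module = func.__self__
        if any(p.requires_grad for p in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)
        return checkpoint_fn(func, *args, **kwargs)

    # _set_gradient_checkpointing is a private transformers hook; its signature can
    # change between versions.
    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_checkpoint_fn)


# Usage, matching the commented-out block in the script above:
# model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
# model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
```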
Sorry for closing the issue due to a mis-touch 😭 I compared the two approaches in the code above; the difference seems to come from the second parameter. The debugging is shown here, and I'm not sure why that would lead to the existing situation. Thanks for your time! 👍
Thanks for all the details, I found the issue that you are experiencing. There is a problem when you enable gradient checkpointing after you have created the PEFT model. As a temporary fix, you should enable it before wrapping the model with PEFT. This PR will fix your issue: huggingface/peft#2398
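In other words, the temporary workaround is a matter of ordering. A minimal sketch is below; the model path and LoRA settings are placeholders taken from the script earlier in the thread, not a prescription:

```python
# Sketch of the suggested ordering workaround: enable gradient checkpointing on the
# base model *before* wrapping it with PEFT.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "/workspace/model/CodeLlama-13b-Instruct-hf",  # placeholder path from the script above
    torch_dtype=torch.float16,
    device_map="auto",
)

# 1) Enable gradient checkpointing on the base model first.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.config.use_cache = False  # use_cache is incompatible with gradient checkpointing

# Optional extra safeguard: make embedding outputs require grad so activations
# entering the checkpointed blocks carry a grad_fn even when embeddings are frozen.
# model.enable_input_require_grads()

# 2) Only then wrap the model with PEFT.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, lora_config)
model.train()
```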
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
Python 3.9.19
transformers 4.42.0
torch 2.2.2+cu118
peft 0.12.0
Who can help?
No response
Reproduction
When I tried using model.gradient_checkpointing_enable() to reduce memory consumption during training, I encountered an error: "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn." After troubleshooting, I found that the issue seems to be caused by loss.requires_grad being set to False, which prevents backpropagation. The reproducible code (the script posted in the comments) directly prints loss.requires_grad, and the output is loss.requires_grad = False.

This is confusing because model.gradient_checkpointing_enable() is designed to reduce memory consumption, but if loss.requires_grad is set to False, it disrupts the normal training process. Meanwhile, when I use similar code from LLaMA-Factory (a custom _gradient_checkpointing_enable) to achieve the effect of model.gradient_checkpointing_enable(), I find that loss.requires_grad is True.
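For what it's worth, the same effect can be seen with plain torch.utils.checkpoint. The following is a minimal, self-contained sketch (not part of the original report) that mimics frozen base weights as under LoRA:

```python
# Minimal sketch (illustration only, not from the issue): with reentrant checkpointing,
# the output's requires_grad is derived from the tensor *inputs* passed to checkpoint,
# not from parameters captured inside the wrapped function. If no input requires grad,
# the checkpointed output has no grad_fn, so a downstream loss ends up with
# requires_grad == False.
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)
for p in layer.parameters():
    p.requires_grad_(False)  # mimic frozen base-model weights under LoRA

x = torch.randn(1, 4)  # e.g. an embedding output that does not require grad
y = checkpoint(layer, x, use_reentrant=True)
print(y.requires_grad)  # False

x.requires_grad_(True)  # roughly what model.enable_input_require_grads() achieves
y = checkpoint(layer, x, use_reentrant=True)
print(y.requires_grad)  # True
```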
Expected behavior
I am not entirely sure if this is a bug in the implementation of model.gradient_checkpointing_enable(). If it is not, please feel free to close the issue directly and let me know. Thank you for taking the time to look into this issue :)