🐛 Bug
Full fine-tuning a T5 model with Opacus fails at optimizer.step() with RuntimeError: stack expects each tensor to be equal size, raised inside clip_and_accumulate while computing per-sample gradient norms. The error does not occur with LoRA, and it disappears when the relative_attention_bias weights are frozen.
Please reproduce using our template Colab and post here the link
https://colab.research.google.com/drive/1Eu0rxSdbdJbZUBlgJ4wxR5QXJ732bc90?usp=sharing (some parts of the code are redundant, apologies for that)
To Reproduce
⚠️ We cannot help you without you sharing reproducible code. Do not ignore this part :)
Steps to reproduce the behavior:
1. Run the cells in order.
2. The model is currently set to full fine-tuning. The error does not appear with LoRA, so it seems specific to certain layers.
3. Memory issues may occur on Google Colab.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-1-78cd311674a4> in <cell line: 191>()
200 task_type = dict_task_types[task_type],
201 dataset_name = dataset_name)
--> 202 obj.finetune_model()
<ipython-input-1-78cd311674a4> in finetune_model(self)
164 total_loss += loss.detach().float()
165 loss.backward()
--> 166 optimizer.step()
167 lr_scheduler.step()
168 optimizer.zero_grad()
/usr/local/lib/python3.10/dist-packages/opacus/optimizers/optimizer.py in step(self, closure)
551 with torch.enable_grad():
552 closure()
--> 553 if self.pre_step():
554 return self.original_optimizer.step()
555 else:
/usr/local/lib/python3.10/dist-packages/opacus/optimizers/optimizer.py in pre_step(self, closure)
536 if self.grad_samples is None or len(self.grad_samples) == 0:
537 return True
--> 538 self.clip_and_accumulate()
539 if self._check_skip_next_step():
540 self._is_last_step_skipped = True
/usr/local/lib/python3.10/dist-packages/opacus/optimizers/optimizer.py in clip_and_accumulate(self)
442 g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
443 ]
--> 444 per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
445 per_sample_clip_factor = (
446 self.max_grad_norm / (per_sample_norms + 1e-6)
RuntimeError: stack expects each tensor to be equal size, but got [1] at entry 0 and [384] at entry 5
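For intuition, the stack failure can be reproduced in isolation with per-parameter norms whose batch dimensions disagree, mirroring the [1] vs. [384] shapes in the error message (toy tensors only, not the actual grad samples):

```python
import torch

# Toy reproduction of the failure inside clip_and_accumulate: every entry of
# per_param_norms is expected to have length == batch size, but one
# parameter's grad_sample here has its batch dimension collapsed to 1.
batch_size = 384
grad_samples = [
    torch.randn(1, 16),           # offending parameter: batch dim collapsed to 1
    torch.randn(batch_size, 16),  # normal parameter: one gradient per sample
]
per_param_norms = [g.reshape(len(g), -1).norm(2, dim=-1) for g in grad_samples]
# Shapes are [1] and [384], so stacking them along dim=1 fails with the
# RuntimeError shown above.
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
```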
Expected behavior
There should be no gradient shape mismatch. A deeper dive suggests the mismatch arises while the activations are being computed, possibly related to the forward hooks; the batch size is in the first dimension for all data points. The error appears to come from the relative_attention_bias weights in the model: when they are frozen, training proceeds as expected. In modeling_t5.py in the Hugging Face Transformers library, a permute is applied to the output of the relative_attention_bias layer and the result is stored in another variable, which might be what triggers this.
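A minimal sketch of the freezing workaround described above, assuming a standard Hugging Face T5 checkpoint (the exact model used in the Colab may differ):

```python
from transformers import T5ForConditionalGeneration

# Workaround sketch: freeze every relative_attention_bias parameter before
# wrapping the model with Opacus, so the per-sample-gradient hooks never
# see these weights. "t5-small" is only an illustrative checkpoint.
model = T5ForConditionalGeneration.from_pretrained("t5-small")
for name, param in model.named_parameters():
    if "relative_attention_bias" in name:
        param.requires_grad = False
```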
Environment
Please copy and paste the output from our environment collection script (or fill out the checklist below manually). You can get the script and run it with:
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
PyTorch Version (e.g., 1.0): 2.5.0+cu121
OS (e.g., Linux): Ubuntu 22.04.3 LTS (x86_64)
How you installed PyTorch (conda, pip, source): N/A
Build command you used (if compiling from source): N/A
Python version: 3.10.12
CUDA/cuDNN version: 12.1
GPU models and configuration: T4
Any other relevant information: Also ran this on a system with an A6000 and an older CUDA version; the same error occurred.
Additional context
The text was updated successfully, but these errors were encountered:
Good catch. I believe the root cause is what you describe: values that share the same name but have different shapes fool the hooks, causing this gradient mismatch.
Need to think more about a general solution to avoid this.
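In the meantime, a quick way to spot which parameter carries the mismatched per-sample gradients (sketch only; assumes the model is already wrapped by Opacus and loss.backward() has been called, so trainable parameters expose a grad_sample attribute):

```python
# Diagnostic sketch: the first dimension of each grad_sample should equal
# the batch size. Any parameter that deviates (e.g. the
# relative_attention_bias weights in this report) is the one that breaks
# torch.stack in clip_and_accumulate.
for name, param in model.named_parameters():
    grad_sample = getattr(param, "grad_sample", None)
    if grad_sample is not None:
        print(name, tuple(grad_sample.shape))
```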