Add gradient noise scale logging #2019


Draft: wants to merge 5 commits into base: sd3

Conversation

rockerBOO (Contributor)

https://arxiv.org/abs/1812.06162

In this paper, we demonstrate that a simple and easy-to-measure statistic called the gradient noise scale predicts the largest useful batch size across many domains and applications, including a number of supervised learning datasets (MNIST, SVHN, CIFAR-10, ImageNet, Billion Word), reinforcement learning domains (Atari and Dota), and even generative model training (autoencoders on SVHN). We find that the noise scale increases as the loss decreases over a training run and depends on the model size primarily through improved model performance.

[Screenshots from arXiv:1812.06162: "Larger batch sizes" and "Simple noise scale" figures]

Because we accumulate the gradients across the whole LoRA network, we want to limit the number of batches (steps) we measure, but we can correlate the loss with the gradient noise scale to find the optimal batch size as described in the paper.
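The paper's "simple" noise scale is B_simple = tr(Σ) / |G|², estimated from gradient norms measured at two different batch sizes (Appendix A of the paper). A minimal sketch of what that estimator could look like, assuming PyTorch; the class and names here are illustrative, not necessarily this PR's actual implementation:

```python
import torch

def grad_norm_sq(params):
    """Squared L2 norm of the current gradients across all parameters."""
    total = 0.0
    for p in params:
        if p.grad is not None:
            total += p.grad.detach().float().pow(2).sum().item()
    return total

class GradientNoiseScale:
    """Running estimate of the simple noise scale B_simple = tr(Sigma) / |G|^2
    (McCandlish et al. 2018, Appendix A) from gradient norms measured at a
    small and a large batch size. Illustrative sketch, not the PR's code."""

    def __init__(self, beta=0.95):
        self.beta = beta     # EMA decay for smoothing the two estimates
        self.ema_g2 = None   # EMA of the |G|^2 estimate
        self.ema_s = None    # EMA of the tr(Sigma) estimate (noise variance)

    def update(self, g2_small, b_small, g2_big, b_big):
        # Unbiased estimates from two batch sizes (paper, Appendix A.1):
        #   |G|^2  ~ (B_big |G_big|^2 - B_small |G_small|^2) / (B_big - B_small)
        #   tr(S)  ~ (|G_small|^2 - |G_big|^2) / (1/B_small - 1/B_big)
        g2 = (b_big * g2_big - b_small * g2_small) / (b_big - b_small)
        s = (g2_small - g2_big) / (1.0 / b_small - 1.0 / b_big)
        self.ema_g2 = g2 if self.ema_g2 is None else self.beta * self.ema_g2 + (1 - self.beta) * g2
        self.ema_s = s if self.ema_s is None else self.beta * self.ema_s + (1 - self.beta) * s

    @property
    def noise_scale(self):
        # The paper recommends the ratio of EMAs rather than an EMA of
        # per-step ratios, which would be more biased.
        if self.ema_g2 is None or self.ema_g2 <= 0:
            return None
        return self.ema_s / self.ema_g2
```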

rockerBOO (Contributor, Author)

Added noise_variance and critical_batch_size, which may be a little misleading as names; refer to the paper for details. I'm accumulating the dynamic batch size as part of the critical batch size calculation (since bucketing can create batches of uneven size), so it should be more accurate than a flat batch size configuration value; see the sketch after the screenshot below.

[Screenshot: Weights & Biases run "faithful-night-290" (women-flux-kohya-lora) showing the logged metrics]
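A sketch of what the accumulation and logging could look like, using the GradientNoiseScale tracker sketched above; the names (`gns`, `logs`) are hypothetical and not necessarily what this PR uses:

```python
class CriticalBatchSizeLogger:
    """Illustrative logging hook: accumulates the actual per-step batch
    sizes (bucketing makes them uneven) instead of trusting a flat
    config value, and logs the noise-scale metrics alongside them."""

    def __init__(self):
        self.total_samples = 0
        self.num_steps = 0

    def log_step(self, batch_size, gns, logs):
        # Accumulate the dynamic batch size for this step.
        self.total_samples += batch_size
        self.num_steps += 1
        avg_batch = self.total_samples / self.num_steps

        b_noise = gns.noise_scale
        if b_noise is not None:
            logs["noise_variance"] = gns.ema_s
            logs["gradient_noise_scale"] = b_noise
            # The paper finds the largest useful ("critical") batch size
            # is approximately the noise scale; compare it against the
            # average dynamic batch size actually used.
            logs["critical_batch_size"] = b_noise
            logs["avg_batch_size"] = avg_batch
```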

I have longer tests, but since I added these metrics afterward I will need to run some more. The idea is that the noise variance should be flat when the batch size is close to ideal, and that the batch size should be changed when it isn't flat. The paper relates the gradient noise scale and the critical batch size, but their relevance is architecture- and dataset-specific, so how the two move relative to each other across runs may be more informative than any single value. The paper doesn't cover diffusion models since it's from 2018, but newer papers may have looked further into this dynamic.
