
Add gradient noise scale logging #2019


Draft · wants to merge 5 commits into base: sd3
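
When `--gradient_noise_scale` is passed, each LoRA module accumulates running sums of its gradients and squared gradients after the backward pass, and once per optimizer step the network turns those sums into a gradient noise scale (GNS) estimate, which is logged together with the noise variance and a critical-batch-size value (the GNS divided by the number of samples accumulated since the last optimizer step). Following the comments in `networks/lora_flux.py`, the per-element variance is Var = E[X²] − E[X]², trace(Σ) is approximated as count · Var (uniform-variance assumption), and GNS = trace(Σ) / ||μ||².
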
1 change: 1 addition & 0 deletions library/train_util.py
@@ -4125,6 +4125,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
parser.add_argument("--gradient_noise_scale", action="store_true", default=False, help="Calculate the gradient noise scale")

if support_dreambooth:
# DreamBooth training
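
For reference, a minimal sketch (a hypothetical standalone parser, not the project's full argument list) of how the new flag behaves once it is registered by `add_training_arguments` in `library/train_util.py`:

```python
import argparse

# Hypothetical standalone parser; in the repo the option is added inside
# library/train_util.py's add_training_arguments() along with all other flags.
parser = argparse.ArgumentParser()
parser.add_argument("--gradient_noise_scale", action="store_true", default=False,
                    help="Calculate the gradient noise scale")

args = parser.parse_args([])                          # flag omitted -> False (default)
print(args.gradient_noise_scale)                      # False
args = parser.parse_args(["--gradient_noise_scale"])  # flag present -> True
print(args.gradient_noise_scale)                      # True
```
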
69 changes: 69 additions & 0 deletions networks/lora_flux.py
@@ -106,6 +106,9 @@ def __init__(
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.grad_count = 0
self.sum_grads = None
self.sum_squared_grads = None

self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
@@ -293,6 +296,19 @@ def update_grad_norms(self):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)

def accumulate_grad(self):
for param in self.parameters():
if param.grad is not None:
grad = param.grad.detach().flatten()
self.grad_count += grad.numel()

# Update running sums
if self.sum_grads is None:
self.sum_grads = grad.sum()
self.sum_squared_grads = (grad**2).sum()
else:
self.sum_grads += grad.sum()
self.sum_squared_grads += (grad**2).sum()

@property
def device(self):
@@ -976,6 +992,59 @@ def combined_weight_norms(self) -> Tensor:
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])

def accumulate_grad(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.accumulate_grad()

def sum_grads(self):
sum_grads = []
sum_squared_grads = []
count = 0
for lora in self.text_encoder_loras + self.unet_loras:
if lora.sum_grads is not None:
sum_grads.append(lora.sum_grads)
if lora.sum_squared_grads is not None:
sum_squared_grads.append(lora.sum_squared_grads)
count += lora.grad_count

return (
torch.stack(sum_grads) if len(sum_grads) > 0 else torch.tensor([]),
torch.stack(sum_squared_grads) if len(sum_squared_grads) > 0 else torch.tensor([]),
count
)

def gradient_noise_scale(self):
sum_grads, sum_squared_grads, count = self.sum_grads()

if count == 0:
return None, None

# Calculate mean gradient and mean squared gradient
mean_grad = torch.mean(sum_grads / count, dim=0)
mean_squared_grad = torch.mean(sum_squared_grads / count, dim=0)

# Variance = E[X²] - E[X]²
variance = mean_squared_grad - mean_grad**2

# GNS = trace(Σ) / ||μ||²
# trace(Σ) = sum of variances = count * variance (for uniform variance assumption)
trace_cov = count * variance
grad_norm_squared = count * mean_grad**2

gradient_noise_scale = trace_cov / grad_norm_squared
# mean_grad = torch.mean(all_grads, dim=0)
#
# # Calculate trace of covariance matrix
# centered_grads = all_grads - mean_grad
# trace_cov = torch.mean(torch.sum(centered_grads**2, dim=0))
#
# # Calculate norm of mean gradient squared
# grad_norm_squared = torch.sum(mean_grad**2)
#
# # Calculate GNS using provided gradient norm squared
# gradient_noise_scale = trace_cov / grad_norm_squared

return gradient_noise_scale.item(), variance.item()

def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
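
The same statistics can be reproduced outside the network classes. Below is a minimal standalone sketch of the estimator (assumptions: plain gradient tensors instead of LoRA modules, and a single running sum rather than the per-module sums that are stacked and averaged above):

```python
import torch

def estimate_gns(grads):
    """Estimate the gradient noise scale as trace(Sigma) / ||mu||^2 from gradient tensors.

    Mirrors the accumulation above: running sums of g and g**2, then
    Var = E[g^2] - E[g]^2 and a uniform-variance approximation for trace(Sigma).
    """
    count = 0
    sum_grads = torch.zeros(())
    sum_squared_grads = torch.zeros(())
    for g in grads:
        g = g.detach().flatten().float()
        count += g.numel()
        sum_grads = sum_grads + g.sum()
        sum_squared_grads = sum_squared_grads + (g ** 2).sum()

    if count == 0:
        return None, None

    mean_grad = sum_grads / count
    mean_squared_grad = sum_squared_grads / count
    variance = mean_squared_grad - mean_grad ** 2   # Var = E[X^2] - E[X]^2
    trace_cov = count * variance                    # trace(Sigma) under uniform variance
    grad_norm_squared = count * mean_grad ** 2      # ||mu||^2 under the same assumption
    gns = trace_cov / grad_norm_squared
    return gns.item(), variance.item()

# Example with fake gradients standing in for param.grad tensors
gns, variance = estimate_gns([torch.randn(16), torch.randn(4, 4)])
```
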
12 changes: 12 additions & 0 deletions train_network.py
@@ -1377,7 +1377,10 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
initial_step = 1

batch_size = 0
for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_batch_size = len(batch['network_multipliers'])
batch_size += current_batch_size
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
@@ -1418,6 +1421,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()
if args.gradient_noise_scale and hasattr(network, "accumulate_grad"):
network.accumulate_grad()

optimizer.step()
lr_scheduler.step()
@@ -1491,6 +1496,10 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen
mean_grad_norm,
mean_combined_norm,
)
if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
gns, variance = network.gradient_noise_scale()
if gns is not None and variance is not None:
logs = {**logs, "gns/gradient_noise_scale": gns, "gns/noise_variance": variance, "gns/critical_batch_size": gns / batch_size}
self.step_logging(accelerator, logs, global_step, epoch + 1)

# VALIDATION PER STEP: global_step is already incremented
@@ -1564,6 +1573,9 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen
accelerator.unwrap_model(network).train()
progress_bar.unpause()

if accelerator.sync_gradients:
batch_size = 0 # reset batch size

if global_step >= args.max_train_steps:
break

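
For reference, a sketch of what lands in the step logs when the flag is set (values are illustrative; `gns` and `variance` come from `network.gradient_noise_scale()`, and `batch_size` is the number of samples accumulated since the last `optimizer.step()`, so the logged critical batch size is the GNS divided by that count):

```python
# Illustrative values only; in the training loop these come from the network and dataloader.
gns, variance, batch_size = 512.0, 3.2e-4, 16

logs = {
    "gns/gradient_noise_scale": gns,
    "gns/noise_variance": variance,
    "gns/critical_batch_size": gns / batch_size,
}
```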