Skip to content

Commit

Permalink
Merge branch 'sd3' into val-loss-improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Jan 27, 2025
2 parents 0750859 + 0778dd9 commit 42c0a9e
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions flux_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def get_noise_pred_and_target(

def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode

with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
Expand Down

0 comments on commit 42c0a9e

Please sign in to comment.