Skip to content

Commit

Permalink
fix Text Encoder only LoRA training
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Jan 27, 2025
1 parent 59b3b94 commit 0778dd9
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion flux_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def get_noise_pred_and_target(
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
Expand Down
2 changes: 1 addition & 1 deletion sd3_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def get_noise_pred_and_target(
t5_attn_mask = None

# call model
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
with torch.set_grad_enabled(is_train), accelerator.autocast():
# TODO support attention mask
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)

Expand Down
2 changes: 1 addition & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def get_noise_pred_and_target(
t.requires_grad_(True)

# Predict the noise residual
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
with torch.set_grad_enabled(is_train), accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
Expand Down

0 comments on commit 0778dd9

Please sign in to comment.