Skip to content

Update sdxl_original_unet.py for apple silicon #2008

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
This pull request wants to merge 1 commit into the base branch (base: main).
Open
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
8 changes: 4 additions & 4 deletions library/sdxl_original_unet.py
(Columns: original file line number · diff line number · diff line change)
Expand Up @@ -29,6 +29,7 @@
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
import torch.amp as amp
from einops import rearrange


Expand Down Expand Up @@ -1110,9 +1111,8 @@ def call_module(module, h, emb, context):
import transformers

optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2

scaler = torch.cuda.amp.GradScaler(enabled=True)

# scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler = amp.GradScaler(enabled=True)
print("start training")
steps = 10
batch_size = 1
Expand All @@ -1127,7 +1127,7 @@ def call_module(module, h, emb, context):
ctx = torch.randn(batch_size, 77, 2048).cuda()
y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()

with torch.cuda.amp.autocast(enabled=True):
with amp.autocast(enabled=True):
output = unet(x, t, ctx, y)
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)
Expand Down