From 4d38727c90d577b3929492baebee9dbaa487e34f Mon Sep 17 00:00:00 2001
From: Carter Feldman
Date: Thu, 27 Mar 2025 09:29:38 +0800
Subject: [PATCH] Update sdxl_original_unet.py for Apple silicon

Remove CUDA-specific stuff to support Apple silicon.
---
 library/sdxl_original_unet.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py
index 6ea4bc332..2bb8482b8 100644
--- a/library/sdxl_original_unet.py
+++ b/library/sdxl_original_unet.py
@@ -29,6 +29,7 @@
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
+import torch.amp as amp
 
 from einops import rearrange
 
@@ -1110,9 +1111,8 @@ def call_module(module, h, emb, context):
     import transformers
 
     optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True)  # working at 22.2GB with torch2
-
-    scaler = torch.cuda.amp.GradScaler(enabled=True)
-
+    # scaler = torch.cuda.amp.GradScaler(enabled=True)
+    scaler = amp.GradScaler(enabled=True)
     print("start training")
     steps = 10
     batch_size = 1
@@ -1127,7 +1127,7 @@ def call_module(module, h, emb, context):
         ctx = torch.randn(batch_size, 77, 2048).cuda()
         y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
 
-        with torch.cuda.amp.autocast(enabled=True):
+        with amp.autocast(enabled=True):
             output = unet(x, t, ctx, y)
             target = torch.randn_like(output)
             loss = torch.nn.functional.mse_loss(output, target)
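
For anyone adapting this further: torch.amp.autocast expects an explicit device_type argument, and torch.amp.GradScaler (PyTorch 2.3+) takes a device string, so a fully device-agnostic setup usually picks the backend first. Below is a minimal sketch of that pattern, not part of the patch itself; it assumes PyTorch >= 2.3, and the Linear model and tensors are hypothetical stand-ins for the file's U-Net and latents.

import torch

# Prefer Apple-silicon MPS, then CUDA, then fall back to CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Gradient scaling is only meaningful for fp16 on CUDA; elsewhere the
# scaler is constructed disabled, so scale()/step()/update() become no-ops.
scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))

model = torch.nn.Linear(8, 8).to(device)      # stand-in for the U-Net
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(1, 8, device=device)          # device=... instead of .cuda()

# autocast needs a device_type; autocast on "mps" needs a recent PyTorch.
with torch.amp.autocast(device_type=device, enabled=(device != "cpu")):
    loss = model(x).square().mean()

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

Selecting the device string once and threading it through GradScaler, autocast, and tensor construction keeps the same training loop working on MPS, CUDA, and CPU without the .cuda() calls that the patch's test harness still uses in its context lines.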