Commit a02cd59: add debug(CPU) mode

xiangze committed Sep 8, 2024 (1 parent: 151a7b8)

Showing 2 changed files with 81 additions and 27 deletions.
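
The commit threads a CPU/debug switch through both files: DDPM gains an isgpu constructor flag that guards every .cuda() call, and train.py gains a DEBUG flag (exposed as -D/--DEBUG) that keeps the model and data on the CPU and bypasses mixed-precision scaling. A minimal usage sketch (the flag values are illustrative, not taken from the commit):

    # CPU/debug construction, mirroring what train.py does when DEBUG is set
    ddpm_model = DDPM(unet, (1e-4, 0.02), scheduler_type=scheduler_type, isgpu=False)
    # equivalent switch from the command line: python train.py -D 1
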
diffuser.py: 57 changes (46 additions, 11 deletions)

@@ -80,10 +80,15 @@ def sample(self, num_samples, size=(1,28,28), num_cls=10, guide_w = 0.0):

 class DDPM(nn.Module):
 
-    def __init__(self, model, betas, T = 500, dropout_p = 0.1, scheduler_type = 'cosine'):
+    def __init__(self, model, betas, T = 500, dropout_p = 0.1, scheduler_type = 'cosine', isgpu=True):
 
         super().__init__()
-        self.model = model.cuda()
+        self.isgpu = isgpu
+        self.model = model
+
+        if(self.isgpu):
+            self.model = self.model.cuda()
 
         for k, v in ddpm_schedule(betas[0], betas[1], T, scheduler_type).items():
             self.register_buffer(k, v)

@@ -94,13 +99,18 @@ def __init__(self, model, betas, T = 500, dropout_p = 0.1, scheduler_type = 'cosine'):

     def forward(self, x, cls):
 
-        timestep = torch.randint(1, self.T, (x.shape[0], )).cuda()
+        timestep = torch.randint(1, self.T, (x.shape[0], ))
+        if(self.isgpu):
+            timestep = timestep.cuda()
 
         noise = torch.randn_like(x)
 
         x_t = (self.sqrt_abar_t[timestep, None, None, None] * x + self.sqrt_abar_t1[timestep, None, None, None] * noise)
 
-        ctx_mask = torch.bernoulli(torch.zeros_like(cls) + self.dropout_p).cuda()
+        ctx_mask = torch.bernoulli(torch.zeros_like(cls) + self.dropout_p)
+        if(self.isgpu):
+            ctx_mask = ctx_mask.cuda()
 
         return noise, x_t, cls, timestep / self.T, ctx_mask
 
     def A(self,x,t,eps,org=False):

@@ -111,11 +121,18 @@ def A(self,x,t,eps,org=False):

     def sample(self, num_samples, size=(1,28,28), num_cls=10, guide_w = 0.0):
 
-        x_i = torch.randn(num_samples, *size).cuda()
-        c_i = torch.arange(0, num_cls).cuda()
+        x_i = torch.randn(num_samples, *size)
+        c_i = torch.arange(0, num_cls)
+        if(self.isgpu):
+            x_i = x_i.cuda()  # move the tensor drawn above; do not redraw it
+            c_i = c_i.cuda()
 
         c_i = c_i.repeat(int(num_samples / c_i.shape[0]))
 
-        ctx_mask = torch.zeros_like(c_i).cuda()
+        ctx_mask = torch.zeros_like(c_i)
+        if(self.isgpu):
+            ctx_mask = ctx_mask.cuda()
 
         c_i = c_i.repeat(2)
         ctx_mask = ctx_mask.repeat(2)
         ctx_mask[num_samples:] = 1.0

@@ -126,13 +143,18 @@ def sample(self, num_samples, size=(1,28,28), num_cls=10, guide_w = 0.0):
         # T, T-1, T-2, ..., 1
         for i in range(self.T - 1, 0, -1):
 
-            t_is = torch.tensor([i / self.T]).cuda()
+            t_is = torch.tensor([i / self.T])
+            if(self.isgpu):
+                t_is = t_is.cuda()
 
             t_is = t_is.repeat(num_samples, 1, 1, 1)
 
             x_i = x_i.repeat(2, 1, 1, 1)
             t_is = t_is.repeat(2, 1, 1, 1)
 
-            z = torch.randn(num_samples, *size).cuda() if i > 1 else 0
+            z = torch.randn(num_samples, *size) if i > 1 else 0
+            if(self.isgpu and i > 1):  # on the last step z is the int 0, which has no .cuda()
+                z = z.cuda()
 
             eps = self.model(x_i, c_i, t_is, ctx_mask)
             eps1 = eps[:num_samples]

@@ -158,7 +180,9 @@ def sample1(self, x,t,z,c_i,ctx_mask, eps,size=(1,28,28), guide_w = 0.0, num_sam
         eps: from Unet
         '''
         #x = x.repeat(2, 1, 1, 1)
-        z = torch.randn(num_samples, *size).cuda()
+        z = torch.randn(num_samples, *size)
+        if(self.isgpu):
+            z = z.cuda()
         x = x[:num_samples]
         return x + (-x + self.sqrt_alpha_t_inv[t] * (x - eps*self.alpha_t_div_sqrt_abar[t]))*dt + self.sqrt_beta_t[t] * z*np.sqrt(dt)

@@ -169,6 +193,16 @@ def __init__(self, model, betas, T = 500, dropout_p = 0.1, scheduler_type = 'cosine'):
         for k, v in ddpm_schedule(betas[0], betas[1], T, scheduler_type).items():
             self.register_buffer(k, v)
 
+    def forward(self, x, cls):
+        timestep = torch.randint(1, self.T, (x.shape[0], )).cuda()
+        noise = torch.randn_like(x)
+
+        x_t = (self.sqrt_abar_t[timestep, None, None, None] * x + self.sqrt_abar_t1[timestep, None, None, None] * noise)
+
+        ctx_mask = torch.bernoulli(torch.zeros_like(cls) + self.dropout_p).cuda()
+
+        return noise, x_t, cls, timestep / self.T, ctx_mask
+
     def A(self,x,t,eps):
         return self.betas[t]*eps

@@ -197,6 +231,7 @@ def sample(self, num_samples, size=(1,28,28), num_cls=10, guide_w = 0.0):
         c_i = c_i.repeat(2)
         ctx_mask = ctx_mask.repeat(2)
         ctx_mask[num_samples:] = 1.0
+
 
         #To Store intermediate results and create GIFs.
         x_is = []
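
One caveat in the hunks above: the forward() added to SMLD still calls .cuda() unconditionally, so the SMLD path remains GPU-only even when train.py runs in DEBUG mode. A common way to avoid threading an isgpu flag through every allocation is to create new tensors on the device of a reference tensor; a minimal sketch of that pattern (the helper name is hypothetical, not part of this repository):

    import torch

    def make_noise_and_timestep(x: torch.Tensor, T: int):
        # New tensors are allocated directly on x's device, so the same
        # code runs on CPU or GPU without any isgpu branching.
        timestep = torch.randint(1, T, (x.shape[0],), device=x.device)
        noise = torch.randn_like(x)  # inherits x's device and dtype
        return timestep, noise

    # works identically for CPU and CUDA inputs
    t, n = make_noise_and_timestep(torch.randn(4, 1, 28, 28), T=500)
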
train.py: 51 changes (35 additions, 16 deletions)

@@ -17,6 +17,8 @@
 from gendata import eda,swissroll
 import numpy as np
 
+DEBUG=True
+
 class AverageMeter:
     def __init__(self, name=None):
         self.name = name

@@ -74,20 +76,29 @@ def train(unet:UNet, ddpm_model:DDPM, loader, opt, criterion, scaler, num_cls, s
     for idx, (img, class_lbl) in enumerate(loop):
         # print(idx,img.shape,class_lbl)
         # exit()
-        img = img.cuda(non_blocking = True)
-        lbl = class_lbl.cuda(non_blocking = True)
 
+        lbl = class_lbl
         opt.zero_grad(set_to_none = True)
 
-        with torch.cuda.amp.autocast_mode.autocast():
-
+        if(DEBUG):
             noise, x_t, ctx, timestep, ctx_mask = ddpm_model(img, lbl)
-            pred = unet(x_t.half(), ctx, timestep.half(), ctx_mask.half())
+            pred = unet(x_t, ctx, timestep, ctx_mask)
             loss = criterion(noise, pred)
+            loss.backward()
+            opt.step()
+        else:
+            img = img.cuda(non_blocking = True)
+            lbl = lbl.cuda(non_blocking = True)
 
-        scaler.scale(loss).backward()
-        scaler.step(opt)
-        scaler.update()
+            with torch.cuda.amp.autocast_mode.autocast():
+                noise, x_t, ctx, timestep, ctx_mask = ddpm_model(img, lbl)
+                pred = unet(x_t.half(), ctx, timestep.half(), ctx_mask.half())
+                loss = criterion(noise, pred)
+
+            scaler.scale(loss).backward()
+            scaler.step(opt)
+            scaler.update()

@@ -183,7 +194,7 @@ def __len__(self):
 parser.add_argument('-sche', '--scheduler_type',default="linear")
 parser.add_argument('-bs', '--batchsize',default=64,type=int)
 parser.add_argument('-ds', '--dataset',default="mnist")
-
+parser.add_argument('-D', '--DEBUG',default=0,type=int)
 args = parser.parse_args()
 # print(args)
 num_epochs = args.num_epochs

@@ -197,7 +208,7 @@ def __len__(self):
 batchsize=args.batchsize
 dataset=args.dataset
 suffix=diftype
-
+DEBUG=args.DEBUG
 
 logfilename="log/TUR_log_skip{}sample{}epoch{}_{}_lr{}_{}".format(skip,TUR_samplenum,num_epochs,scheduler_type,lr,suffix)
 if (init_every_sample):

@@ -248,18 +259,26 @@ def __len__(self):
 loader = DataLoader(ds, batch_size = batchsize, shuffle = True, num_workers = 0)
 print(len(loader.dataset))
 
-unet = UNet(img_size[0], 128, num_cls).cuda()
+unet = UNet(img_size[0], 128, num_cls)
+if(not DEBUG):
+    unet = unet.cuda()
 
 if(diftype=="SMLD"):
-    ddpm_model = SMLD(unet, (1e-4, 0.02),scheduler_type=scheduler_type).cuda()
+    ddpm_model = SMLD(unet, (1e-4, 0.02),scheduler_type=scheduler_type)
 else:
-    ddpm_model = DDPM(unet, (1e-4, 0.02),scheduler_type=scheduler_type).cuda()
+    ddpm_model = DDPM(unet, (1e-4, 0.02),scheduler_type=scheduler_type,isgpu=not DEBUG)
+
+if(not DEBUG):
+    ddpm_model = ddpm_model.cuda()
 
 opt = torch.optim.Adam(list(ddpm_model.parameters()) + list(unet.parameters()), lr =lr)
 criterion = nn.MSELoss()
 
-scaler = torch.cuda.amp.grad_scaler.GradScaler()
+if(DEBUG):
+    scaler = None
+else:
+    scaler = torch.cuda.amp.grad_scaler.GradScaler()
 
 ws = [0.0, 0.5, 1.0]
 
 for epoch in range(num_epochs):
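
The DEBUG branch in train() duplicates the whole optimization step. Both torch.cuda.amp.autocast and GradScaler accept an enabled= argument, so the same behavior fits in one code path; a sketch under that assumption (the names mirror train.py, but this is not the commit's code):

    use_amp = not DEBUG  # mixed precision only for the GPU path
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for img, lbl in loader:
        if use_amp:
            img = img.cuda(non_blocking=True)
            lbl = lbl.cuda(non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            noise, x_t, ctx, timestep, ctx_mask = ddpm_model(img, lbl)
            pred = unet(x_t, ctx, timestep, ctx_mask)  # autocast handles the fp16 casts
            loss = criterion(noise, pred)
        scaler.scale(loss).backward()  # plain loss.backward() when disabled
        scaler.step(opt)               # plain opt.step() when disabled
        scaler.update()                # no-op when disabled
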
