Skip to content

Commit

Permalink
implement swissroll dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangze committed Aug 11, 2024
1 parent 4e1c6b6 commit 706acf8
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from model.Unet import UNet
from diffuser import DDPM,SMLD
from TURsample import TUR_sample

from gendata import eda,swissroll
import numpy as np

class AverageMeter:
def __init__(self, name=None):
Expand Down Expand Up @@ -148,7 +149,7 @@ def animate_plot(i, xis):

if __name__ == '__main__':
# wandb.init(project = 'MinDiffusion')
num_cls = 10


parser = argparse.ArgumentParser(description='学習、生成段階におけるエントロピー生成率と変数のゆらぎを出力する')
parser.add_argument('-n', '--num_epochs',default=10,type=int)
Expand Down Expand Up @@ -176,13 +177,6 @@ def animate_plot(i, xis):
dataset=args.dataset
suffix=diftype

unet = UNet(1, 128, num_cls).cuda()

if(diftype=="SMLD"):
ddpm_model = SMLD(unet, (1e-4, 0.02),scheduler_type=scheduler_type).cuda()
else:
ddpm_model = DDPM(unet, (1e-4, 0.02),scheduler_type=scheduler_type).cuda()


logfilename="log/TUR_log_skip{}sample{}epoch{}_{}_lr{}_{}".format(skip,TUR_samplenum,num_epochs,scheduler_type,lr,suffix)
if (init_every_sample):
Expand All @@ -193,24 +187,49 @@ def animate_plot(i, xis):
tr = T.Compose([T.ToTensor()])
if(dataset=="mnist"):
ds = tv.datasets.MNIST('data', True, transform = tr, download = True)
img_size=(1,28,28),
img_size=(1,28,28)
num_cls = 10
if(dataset=="kmnist"):
ds = tv.datasets.KMNIST('data', True, transform = tr, download = True)
img_size=(1,28,28),
img_size=(1,28,28)
num_cls = 10
elif(dataset=="cifar10"):
ds = tv.datasets.CIFAR10('data', True, transform = tr, download = True)
img_size=(3,28,28),
img_size=(3,32,32) # 128,1,3,3
num_cls = 10
elif(dataset=="cifar100"):
ds = tv.datasets.CIFAR100('data', True, transform = tr, download = True)
img_size=(3,28,28),
img_size=(3,32,32)
num_cls = 100
elif(dataset=="imagenet"):

img_size=(3,28,28),
ds = tv.datasets.imagenet('data', True, transform = tr, download = True)
img_size=(3,28,28)
elif(dataset=="eda"):
ddim=5
img_size=(1,4,4)
ds=eda(ddim,4*4).reshape(img_size)
elif(dataset=="swissroll"):
N=6000
img_size=(1,28,28)
ddim=28*28
num_cls=10
d,label=swissroll(N,ddim)
label=np.array(label)
label=((label-min(label))/(max(label)-min(label))*num_cls).astype("int")
ds = torch.utils.data.TensorDataset( torch.Tensor(d), torch.Tensor(label))
else:
print("unsupported dataset {}".format(dataset))

# print(ds.shape)
loader = DataLoader(ds, batch_size = batchsize, shuffle = True, num_workers = 0)

unet = UNet(img_size[0], 128, num_cls).cuda()

if(diftype=="SMLD"):
ddpm_model = SMLD(unet, (1e-4, 0.02),scheduler_type=scheduler_type).cuda()
else:
ddpm_model = DDPM(unet, (1e-4, 0.02),scheduler_type=scheduler_type).cuda()

opt = torch.optim.Adam(list(ddpm_model.parameters()) + list(unet.parameters()), lr =lr)
criterion = nn.MSELoss()

Expand Down

0 comments on commit 706acf8

Please sign in to comment.