train.py
import argparse

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from networks.gan import Generator, Discriminator

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lr",
    type=float,
    help="learning rate",
    default=0.001,
)
parser.add_argument(
    "--batch_size",
    type=int,
    help="batch size for processing the images",
    default=32,
)
parser.add_argument(
    "--epoch_size",
    type=int,
    help="number of epochs to train the GAN for",
    default=100,
)
args = parser.parse_args()
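# Example invocation (the values below are purely illustrative, not
# recommendations from this repository; omit any flag to use its default):
#   python train.py --lr 0.0002 --batch_size 64 --epoch_size 50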

def train():
    lr = args.lr
    z_dim = 64
    image_dim = 28 * 28 * 1  # 784, a flattened 28x28 MNIST image
    batch_size = args.batch_size
    num_epochs = args.epoch_size
    device = "cuda" if torch.cuda.is_available() else "cpu"

    discr = Discriminator(image_dim).to(device)
    genr = Generator(z_dim, image_dim).to(device)

    # Scale the MNIST images from [0, 1] to [-1, 1]
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    opt_disc = optim.Adam(discr.parameters(), lr=lr)
    opt_gen = optim.Adam(genr.parameters(), lr=lr)
    criterion = nn.BCELoss()
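    # BCELoss with target 1 for real and 0 for fake recovers the standard GAN
    # objective: BCE(D(x), 1) = -log D(x) and BCE(D(G(z)), 0) = -log(1 - D(G(z))),
    # so minimizing the discriminator's two loss terms below maximizes
    # log D(x) + log(1 - D(G(z))).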
    for epoch in range(num_epochs):
        with tqdm(loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")
            for batch_idx, (real, _) in enumerate(tepoch):
                real = real.view(-1, 784).to(device)
                batch_size = real.shape[0]  # the last batch may be smaller

                ### Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
                ### where x is a real sample, z is random noise, D is the
                ### discriminator and G is the generator
                noise = torch.randn(batch_size, z_dim).to(device)
                fake = genr(noise)
                disc_real = discr(real).view(-1)
                lossD_real = criterion(disc_real, torch.ones_like(disc_real))
                disc_fake = discr(fake).view(-1)
                lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
                lossD = (lossD_real + lossD_fake) / 2
                discr.zero_grad()
                # retain_graph=True keeps the generator's part of the graph alive
                # so it can be reused by lossG.backward() below
                lossD.backward(retain_graph=True)
                opt_disc.step()
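                ### Train Generator: maximize log(D(G(z))), the non-saturating form
                ### of the generator objective. Labelling the fakes as real in the
                ### BCE loss is equivalent, since BCE(D(G(z)), 1) = -log D(G(z)).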
                output = discr(fake).view(-1)  # re-score the fakes with the just-updated discriminator
                lossG = criterion(output, torch.ones_like(output))
                genr.zero_grad()
                lossG.backward()
                opt_gen.step()

                tepoch.set_postfix(d_loss=lossD.item(), g_loss=lossG.item())


if __name__ == '__main__':
    train()
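
For reference, train.py only assumes that networks/gan.py exposes a Generator(z_dim, img_dim) whose output matches the flattened, [-1, 1]-normalized images, and a Discriminator(img_dim) whose output lies in (0, 1) so nn.BCELoss can be applied directly. The sketch below is a minimal MLP pair satisfying those assumptions; the actual classes in this repository may be defined differently.

# Minimal sketch only; the real networks/gan.py may differ.
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # keep outputs in (0, 1) for BCELoss
        )

    def forward(self, x):
        return self.net(x)


class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # match the [-1, 1] range of the normalized images
        )

    def forward(self, z):
        return self.net(z)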