# dcgan_mnist.jl
# Deep Convolutional GAN (DCGAN) trained on the MNIST digits with Flux.jl.
using Base.Iterators: partition
using Flux
using Flux.Optimise: update!
using Flux.Losses: logitbinarycrossentropy
using Images
using MLDatasets
using Statistics
using Printf
using Random
using CUDA
# Make scalar indexing of CUDA arrays an error rather than a slow fallback, so any code
# that accidentally indexes GPU arrays element by element fails loudly.
CUDA.allowscalar(false)
"""
    HyperParams(; kwargs...)

Training configuration for the DCGAN. Every field has a default and can be overridden
by keyword, e.g. `HyperParams(; epochs = 5)`.

# Fields
- `batch_size`: number of real images per training batch.
- `latent_dim`: length of each latent noise vector fed to the generator.
- `epochs`: number of passes over the training set.
- `verbose_freq`: log losses and save a sample grid every this many training steps.
- `output_x`, `output_y`: shape of the saved sample grid (`output_x * output_y` samples;
  samples are grouped `output_y` at a time when the grid is assembled).
- `lr_dscr`, `lr_gen`: Adam learning rates of the discriminator and generator.
"""
Base.@kwdef struct HyperParams
    batch_size::Int = 128
    latent_dim::Int = 100
    epochs::Int = 20
    verbose_freq::Int = 1_000
    output_x::Int = 6
    output_y::Int = 6
    lr_dscr::Float32 = 2.0e-4
    lr_gen::Float32 = 2.0e-4
end
"""
    create_output_image(gen, fixed_noise, hparams)

Run every latent array in `fixed_noise` through `gen`, copy the results to the CPU, and
tile them into one grayscale image.

Each group of `hparams.output_y` samples is concatenated along dimension 2 into a strip,
and the strips are stacked along dimension 1. The singleton dimensions 3 and 4 are then
dropped and the two spatial axes swapped. Pixel values `v` are mapped to `(v + 1) / 2`,
which takes the generator's `tanh` range [-1, 1] to [0, 1].
"""
function create_output_image(gen, fixed_noise, hparams)
    samples = map(z -> cpu(gen(z)), fixed_noise)
    strips = [reduce(hcat, group) for group in partition(samples, hparams.output_y)]
    grid = reduce(vcat, strips)
    # NOTE(review): the axis swap presumably makes digits display upright (MNIST arrays
    # appear to be stored width-first) — confirm.
    pixels = permutedims(dropdims(grid; dims = (3, 4)), (2, 1))
    return Gray.(pixels .+ 1f0) ./ 2f0
end
# weight initialization as given in the paper https://arxiv.org/abs/1511.06434
"""
    dcgan_init([rng::AbstractRNG], shape...)

Return `Float32` values of size `shape` drawn from a normal distribution with mean 0 and
standard deviation 0.02.

Pass an explicit `rng` for reproducible initialisation. Without one,
`Random.default_rng()` is used, which is the same generator plain `randn` draws from, so
existing callers (e.g. Flux layers calling `init(dims...)`) see unchanged behaviour.
"""
dcgan_init(rng::AbstractRNG, shape...) = 0.02f0 .* randn(rng, Float32, shape...)
dcgan_init(shape...) = dcgan_init(Random.default_rng(), shape...)
"""
    Discriminator()

Build the DCGAN discriminator. It has two strided 4×4 convolutions, each followed by
LeakyReLU (slope 0.2) and 25% dropout. A dense layer then produces one output per image;
the training code treats that output as a logit.

The hard-coded flatten size `7 * 7 * 128` assumes 28×28×1×N inputs. Each stride-2
convolution halves the spatial size: 28 → 14 → 7.
"""
function Discriminator()
    lrelu(x) = leakyrelu.(x, 0.2f0)
    downsample(channels) = (
        Conv((4, 4), channels; stride = 2, pad = 1, init = dcgan_init),
        lrelu,
        Dropout(0.25),
    )
    flatten_features(x) = reshape(x, 7 * 7 * 128, :)
    return Chain(
        downsample(1 => 64)...,
        downsample(64 => 128)...,
        flatten_features,
        Dense(7 * 7 * 128, 1),
    )
end
"""
    Generator(latent_dim::Int)

Build the DCGAN generator. A latent batch of size `latent_dim × N` is projected by a dense
layer to a 7×7×256 feature map. A stride-1 transposed convolution keeps it at 7×7, two
stride-2 transposed convolutions upsample it to 14×14 and then 28×28×1, and a final
`tanh` bounds the outputs to [-1, 1].
"""
function Generator(latent_dim::Int)
    project = (
        Dense(latent_dim, 7 * 7 * 256),
        BatchNorm(7 * 7 * 256, relu),
        z -> reshape(z, 7, 7, 256, :),
    )
    upsample = (
        ConvTranspose((5, 5), 256 => 128; stride = 1, pad = 2, init = dcgan_init),
        BatchNorm(128, relu),
        ConvTranspose((4, 4), 128 => 64; stride = 2, pad = 1, init = dcgan_init),
        BatchNorm(64, relu),
        ConvTranspose((4, 4), 64 => 1; stride = 2, pad = 1, init = dcgan_init),
    )
    return Chain(project..., upsample..., z -> tanh.(z))
end
# Loss functions
"""
    discriminator_loss(real_output, fake_output)

Return the sum of two logit binary cross-entropies: one labelling the discriminator's
outputs on real images as 1, the other labelling its outputs on generated images as 0.
"""
function discriminator_loss(real_output, fake_output)
    return logitbinarycrossentropy(real_output, 1) +
           logitbinarycrossentropy(fake_output, 0)
end
"""
    generator_loss(fake_output)

Return the logit binary cross-entropy of the discriminator's outputs on generated images
against the "real" label 1. The loss is low when the discriminator is fooled.
"""
function generator_loss(fake_output)
    real_label = 1
    return logitbinarycrossentropy(fake_output, real_label)
end
"""
    train_discriminator!(gen, dscr, x, opt_dscr, hparams)

Take one optimiser step on `dscr` using the real batch `x` and an equally sized batch of
generated images. Returns the discriminator loss before the update.

The noise is allocated with `similar(x, ...)`, so it lives on the same device and has
the same element type as `x`. Its batch size is taken from the last dimension of `x`
rather than from `hparams.batch_size`, because the final batch of an epoch can be
smaller (e.g. 60000 MNIST images in batches of 128 leave 96).
"""
function train_discriminator!(gen, dscr, x, opt_dscr, hparams)
    batch = size(x, ndims(x))
    noise = randn!(similar(x, (hparams.latent_dim, batch)))
    # NOTE(review): `gen` runs outside the gradient context here. With Flux's automatic
    # train/test mode, its BatchNorm layers presumably use running statistics for this
    # call, unlike in `train_generator!` — confirm this is intended.
    fake_input = gen(noise)
    loss, grads = Flux.withgradient(dscr) do d
        discriminator_loss(d(x), d(fake_input))
    end
    update!(opt_dscr, dscr, grads[1])
    return loss
end
"""
    train_generator!(gen, dscr, x, opt_gen, hparams)

Take one optimiser step on `gen` so that `dscr` scores its samples as real. Returns the
generator loss before the update.

`x` is the current real batch. It is used only to allocate the noise: `similar(x, ...)`
gives the same device and element type, and the last dimension of `x` gives the batch
size. This keeps the generator step consistent with the discriminator step, including on
a short final batch.
"""
function train_generator!(gen, dscr, x, opt_gen, hparams)
    batch = size(x, ndims(x))
    noise = randn!(similar(x, (hparams.latent_dim, batch)))
    loss, grads = Flux.withgradient(gen) do g
        generator_loss(dscr(g(noise)))
    end
    update!(opt_gen, gen, grads[1])
    return loss
end
# Load the MNIST training images and map pixels to [-1, 1] via 2v - 1 (assumes features
# in [0, 1]). Then split them, in order and unshuffled, into batches of at most
# `batch_size`, each passed through `gpu`.
function _mnist_batches(batch_size)
    images = MLDatasets.MNIST(:train).features
    image_tensor = reshape(@.(2f0 * images - 1f0), 28, 28, 1, :)
    n_images = size(image_tensor, 4)
    return [image_tensor[:, :, :, r] |> gpu for r in partition(1:n_images, batch_size)]
end

# Render the generator's samples for `fixed_noise` and write them to
# output/dcgan_steps_<step>.png.
function _save_sample_grid(gen, fixed_noise, hparams, step)
    output_image = create_output_image(gen, fixed_noise, hparams)
    return save(@sprintf("output/dcgan_steps_%06d.png", step), output_image)
end

"""
    train(; kws...)

Train a DCGAN on the MNIST training set. All keyword arguments are forwarded to
`HyperParams`.

Every `verbose_freq` training steps (starting at step 0), both losses are logged and a
grid of samples is written to `output/dcgan_steps_<step>.png`. One more grid is written
after the last epoch. The `output` directory is created if it does not exist.

Throws `ArgumentError` if `verbose_freq` is not positive.
"""
function train(; kws...)
    # Model Parameters
    hparams = HyperParams(; kws...)
    hparams.verbose_freq > 0 ||
        throw(ArgumentError("verbose_freq must be positive, got $(hparams.verbose_freq)"))
    if CUDA.functional()
        @info "Training on GPU"
    else
        @warn "Training on CPU, this will be very slow!" # 20 mins/epoch
    end
    data = _mnist_batches(hparams.batch_size)
    # The same latent vectors are reused for every saved grid, so the grids are comparable.
    fixed_noise = [randn(Float32, hparams.latent_dim, 1) |> gpu
                   for _ in 1:(hparams.output_x * hparams.output_y)]
    dscr = Discriminator() |> gpu
    gen = Generator(hparams.latent_dim) |> gpu
    opt_dscr = Flux.setup(Adam(hparams.lr_dscr), dscr)
    opt_gen = Flux.setup(Adam(hparams.lr_gen), gen)
    # Writing an image into a missing directory fails, so create it up front.
    mkpath("output")
    train_steps = 0
    for ep in 1:hparams.epochs
        @info "Epoch $ep"
        for x in data
            # Update discriminator and generator
            loss_dscr = train_discriminator!(gen, dscr, x, opt_dscr, hparams)
            loss_gen = train_generator!(gen, dscr, x, opt_gen, hparams)
            if train_steps % hparams.verbose_freq == 0
                @info("Train step $(train_steps), Discriminator loss = $(loss_dscr), Generator loss = $(loss_gen)")
                _save_sample_grid(gen, fixed_noise, hparams, train_steps)
            end
            train_steps += 1
        end
    end
    return _save_sample_grid(gen, fixed_noise, hparams, train_steps)
end
# Start training only when this file is run as a script (`julia dcgan_mnist.jl`),
# not when it is `include`d from another file or the REPL.
if (@__FILE__) == abspath(PROGRAM_FILE)
    train()
end