forked from houxianxu/DFC-VAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.lua
70 lines (59 loc) · 2.11 KB
/
generate.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
require 'torch'
require 'nn'
require 'image'
require 'Sampler'
disp = require 'display'
util = paths.dofile('util.lua')
-- Script options. Every key can be overridden by an environment variable of
-- the same name (e.g. `gpu=0 reconstruction=1 th generate.lua`); values that
-- parse as numbers are converted with tonumber, anything else stays a string.
-- NOTE(review): kept global on purpose -- `opt` is handed to the data loader,
-- which may also read it as a global; confirm before making it local.
opt = {
   dataset = 'folder',
   batchSize = 100,
   loadSize = 128,
   fineSize = 128,
   nz = 100,            -- # of dim for Z (not read here; z's size comes from the encoder output)
   nThreads = 3,        -- # of data loading threads to use
   display = 1,         -- inherited from the training options; the disp.image calls in this script do not check it
   gpu = 1,             -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
   decoder = 'checkpoints/cvae_content_123_decoder.t7',
   encoder = 'checkpoints/cvae_content_123_encoder.t7',
   image_folder = 'images/',
   reconstruction = 0,  -- 1 = decode the encoded batch; otherwise decode random codes
   sample_std = 0.4,    -- scale applied to N(0, 1) latent codes when reconstruction ~= 1
}
-- Overwriting existing keys while traversing with pairs is permitted in Lua.
for key, default in pairs(opt) do
   local env_value = os.getenv(key)
   opt[key] = tonumber(env_value) or env_value or default
end
-- Choose the tensor type used for the networks and data. Following the
-- documented opt.gpu convention, 0 means CPU and X > 0 means CUDA device X
-- (cutorch device indices are 1-based, so opt.gpu is passed through as-is).
local dtype = 'torch.FloatTensor'
if opt.gpu > 0 then
   require 'cutorch'
   require 'cunn'
   require 'cudnn'
   -- Without this, everything would run on cutorch's default device
   -- regardless of which GPU was requested.
   cutorch.setDevice(opt.gpu)
   dtype = 'torch.CudaTensor'
   print(string.format('Running with CUDA on GPU %d', opt.gpu))
end
-- Load the trained networks and convert them to the selected tensor type.
-- NOTE(review): the networks are never switched to :evaluate(); if they
-- contain BatchNormalization/Dropout, inference uses training-mode behavior.
-- Confirm whether that is intended.
local decoder = torch.load(opt.decoder):type(dtype)
local encoder = torch.load(opt.encoder):type(dtype)

-- Fetch one batch of real images from the data loader.
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
local real = data:getBatch()

-- Preprocess a float copy on the CPU, then convert once to the compute type.
-- The explicit copy keeps `real` holding the raw batch: Tensor:type()/float()
-- return the same tensor when no conversion is needed, so preprocessing
-- through them would overwrite `real` on the CPU path.
local input = torch.FloatTensor(real:size()):copy(real)
for i = 1, input:size(1) do
   -- Hand util.preprocess a private copy in case it modifies its argument.
   input[i] = util.preprocess(input[i]:clone())
end
input = input:type(dtype)
-- Encode the batch; the encoder's output is a pair {mean, log_var}, which
-- nn.Sampler turns into latent codes z.
local results = encoder:forward(input)
local mean, log_var = results[1], results[2]
local z = nn.Sampler():forward({mean, log_var})

if opt.reconstruction ~= 1 then
   -- Generation mode: discard the encoded codes and decode random ones of the
   -- same size, N(0, 1) scaled by opt.sample_std (0.4 when the option is unset).
   local sample_std = tonumber(opt.sample_std) or 0.4
   z = torch.randn(z:size()):type(dtype):mul(sample_std)
end
-- Decode the latent codes and show the images on the display server; in
-- reconstruction mode the input batch is shown as well.
local fake = decoder:forward(z)

-- Bring both batches to the CPU once. :float() returns the tensor itself when
-- it is already a FloatTensor and makes a single copy from a CudaTensor, which
-- avoids a device/host round trip per image.
input = input:float()
fake = fake:float()
for i = 1, input:size(1) do
   -- util.deprocess gets a private copy in case it modifies its argument;
   -- the inputs are additionally clamped to [0, 1].
   input[i] = torch.clamp(util.deprocess(input[i]:clone()), 0, 1)
   fake[i] = util.deprocess(fake[i]:clone())
end

disp.image(fake, {win=1000, title='test'})
if opt.reconstruction == 1 then
   disp.image(input, {win=12, title='test input'})
end
-- image.save(opt.image_folder .. 'low_level_output_content.jpg', image.toDisplayTensor{input=fake, nrow=10})
-- image.save(opt.image_folder .. 'real.jpg', image.toDisplayTensor{input=real, nrow=10})