-
Notifications
You must be signed in to change notification settings - Fork 22
/
util.lua
141 lines (124 loc) · 4.12 KB
/
util.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
local util = {}
-- Serialize `net` to `filename` as a compact, CPU-loadable checkpoint.
-- Side effects on `net` itself: it is cast to float (CPU) and, when
-- gpu > 0, cast back to CUDA afterwards; what is written to disk is a
-- stripped clone, not `net`.
-- @param filename  path to write the serialized network to
-- @param net       the nn network to save (temporarily moved to CPU)
-- @param gpu       > 0 means `net` lives on the GPU and must be restored
function util.save(filename, net, gpu)
net:float() -- if needed, bring back to CPU
local netsave = net:clone()
-- restore the working network to the GPU; from here on we only touch the clone
if gpu > 0 then
net:cuda()
end
-- convert each top-level module of the clone to a plain-nn CPU equivalent
-- and clear its transient buffers (see util.optimizemodule)
for k, l in ipairs(netsave.modules) do
netsave.modules[k] = util.optimizemodule(netsave.modules[k])
end
-- replace cached network outputs with empty tensors of the same type so
-- the checkpoint does not embed a full batch of activations
if torch.type(netsave.output) == 'table' then
for k, o in ipairs(netsave.output) do
netsave.output[k] = o.new()
end
else
netsave.output = netsave.output.new()
end
netsave.gradInput = netsave.gradInput.new()
-- drop gradient buffers entirely; util.load reallocates them on load
netsave:apply(function(m) if m.weight then m.gradWeight = nil; m.gradBias = nil; end end)
torch.save(filename, netsave)
end
-- Convert a (possibly GPU-/cudnn-/fbnn-specific) module into a plain-nn,
-- CPU-compatible equivalent and strip its transient buffers so that it
-- serializes compactly. Recurses into containers (anything exposing a
-- `modules` table). Returns the possibly-replaced module.
-- @param m  the module to optimize (may be replaced, not mutated in place
--           for the conversion branches)
-- @return   a CPU-compatible module with cleared buffers
function util.optimizemodule(m)
-- convert to CPU compatible model
if torch.type(m) == 'cudnn.SpatialConvolution' then
local new = nn.SpatialConvolution(m.nInputPlane, m.nOutputPlane,
m.kW, m.kH, m.dW, m.dH,
m.padW, m.padH)
new.weight:copy(m.weight)
new.bias:copy(m.bias)
m = new
elseif torch.type(m) == 'fbnn.SpatialBatchNormalization' then
-- BUGFIX: the original tested `torch.type(l)` where `l` was an
-- undeclared global (always nil here), so this branch could never
-- fire; `new` was also leaking as an accidental global.
local new = nn.SpatialBatchNormalization(m.weight:size(1), m.eps,
m.momentum, m.affine)
new.running_mean:copy(m.running_mean)
new.running_std:copy(m.running_std)
if m.affine then
new.weight:copy(m.weight)
new.bias:copy(m.bias)
end
m = new
elseif m['modules'] then
-- container: recurse into the children and return the container
-- itself without touching its own (nonexistent) buffers
for k, child in ipairs(m.modules) do
m.modules[k] = util.optimizemodule(child)
end
return m
end
-- clean up buffers: replace cached tensors with empty ones of the same
-- type, and drop batch-norm scratch state outright
m.output = m.output.new()
m.gradInput = m.gradInput.new()
m.finput = m.finput and m.finput.new() or nil
m.fgradInput = m.fgradInput and m.fgradInput.new() or nil
m.buffer = nil
m.buffer2 = nil
m.centered = nil
m.std = nil
m.normalized = nil
-- TODO: figure out why giant storage-offsets being created on typecast
-- clone parameters so the saved tensors own compact storages
if m.weight then
m.weight = m.weight:clone()
m.gradWeight = m.gradWeight:clone()
m.bias = m.bias:clone()
m.gradBias = m.gradBias:clone()
end
return m
end
-- Load a network saved by util.save and reallocate the gradient buffers
-- that util.save stripped out (zero-filled, shaped like the parameters).
-- @param filename  path of the checkpoint written by util.save
-- @param gpu       accepted for interface symmetry; not consulted here
-- @return          the deserialized network, ready for training
function util.load(filename, gpu)
local net = torch.load(filename)
local restore_grad_buffers = function(m)
if m.weight then
m.gradWeight = m.weight:clone():zero()
m.gradBias = m.bias:clone():zero()
end
end
net:apply(restore_grad_buffers)
return net
end
-- Replace every top-level nn.SpatialConvolution in `net` with an
-- equivalent cudnn.SpatialConvolution (weights and bias copied over).
-- A no-op when the cudnn package is unavailable.
-- @param net  network whose top-level `modules` list is rewritten in place
-- @return     `net`, for chaining
function util.cudnn(net)
-- require once instead of once per module; `require` caches its result,
-- so this is behaviorally identical but avoids repeated pcall overhead
local has_cudnn = pcall(require, 'cudnn')
if not has_cudnn then
return net
end
for k, l in ipairs(net.modules) do
-- convert to cudnn
if torch.type(l) == 'nn.SpatialConvolution' then
local new = cudnn.SpatialConvolution(l.nInputPlane, l.nOutputPlane,
l.kW, l.kH, l.dW, l.dH,
l.padW, l.padH)
new.weight:copy(l.weight)
new.bias:copy(l.bias)
net.modules[k] = new
end
end
return net
end
-- a function to do memory optimizations by
-- setting up double-buffering across the network.
-- this drastically reduces the memory needed to generate samples.
-- NOTE: after this call the net is inference-only — outputs are aliased
-- across layers, so intermediate activations are overwritten each step
-- and backprop would produce garbage.
function util.optimizeInferenceMemory(net)
-- two shared output buffers: one for convolutions, one for batch-norm.
-- Alternating layer types (conv -> BN -> conv ...) therefore ping-pong
-- between just two activation tensors instead of one per layer.
local finput, output, outputB
net:apply(
function(m)
if torch.type(m):find('Convolution') then
-- first conv seen donates its buffers; all later convs alias them
finput = finput or m.finput
m.finput = finput
output = output or m.output
m.output = output
elseif torch.type(m):find('ReLU') then
-- in-place ReLU needs no output buffer at all
m.inplace = true
elseif torch.type(m):find('BatchNormalization') then
outputB = outputB or m.output
m.output = outputB
end
end)
end
-- Convert an RGB image (values in [0, 1], channel-first) into the BGR,
-- [0, 255], mean-subtracted layout used by Caffe-style VGG networks.
-- The per-channel means are the standard VGG/Caffe BGR means.
-- @param img  3xHxW FloatTensor, RGB, values in [0, 1]
-- @return     a new 3xHxW tensor: BGR, scaled to [0, 255], mean-subtracted
function util.preprocess(img)
-- reorder channels RGB -> BGR (index returns a fresh tensor)
local bgr_order = torch.LongTensor{3, 2, 1}
local out = img:index(1, bgr_order)
out:mul(255.0)
-- subtract the broadcast per-channel mean in place
local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68})
out:add(-1, mean_pixel:view(3, 1, 1):expandAs(out))
return out
end
-- Undo the above preprocessing.
-- Adds the per-channel mean back, reorders BGR -> RGB, and rescales the
-- values from [0, 255] down to [0, 1]. Returns a new tensor.
-- @param img  3xHxW mean-subtracted BGR tensor (as built by util.preprocess)
-- @return     a new 3xHxW RGB tensor with values in [0, 1]
function util.deprocess(img)
local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68})
-- `+` allocates a fresh tensor, leaving the input untouched
local restored = img + mean_pixel:view(3, 1, 1):expandAs(img)
local rgb_order = torch.LongTensor{3, 2, 1}
return restored:index(1, rgb_order):div(255.0)
end
return util