-
Notifications
You must be signed in to change notification settings - Fork 0
/
LookupTable.lua
164 lines (141 loc) · 4.73 KB
/
LookupTable.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
local THNN = require 'nn.THNN'
-- nn.LookupTable: a trainable table of row vectors. The forward pass gathers
-- rows of `weight` by index (see updateOutput), so row i of `weight` is the
-- output vector for index i. Inherits from nn.Module via `parent`.
local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')
-- Class version tag. NOTE(review): presumably consulted when loading modules
-- serialized by older versions (cf. backCompatibility) — confirm.
LookupTable.__version = 4
--- Construct a lookup table of `nIndex` rows, each a vector of `nOutput` values.
-- @param nIndex number of rows (valid indices) in the table
-- @param nOutput length of each row vector
-- @param paddingValue index forwarded to the THNN gradient kernel as its
--   padding value; defaults to 0. NOTE(review): presumably rows with this
--   index receive no gradient — confirm against THNN.
-- @param maxNorm optional; when set, `renorm` is run on the looked-up rows at
--   every forward pass (presumably clamping their norm to maxNorm)
-- @param normType optional p of the norm used by `renorm` (2 when nil)
function LookupTable:__init(nIndex, nOutput, paddingValue, maxNorm, normType)
   parent.__init(self)

   self.weight = torch.Tensor(nIndex, nOutput)
   self.gradWeight = torch.Tensor(nIndex, nOutput):zero()
   self.paddingValue = paddingValue or 0
   -- nil means "feature disabled"; `x or nil` would be a no-op.
   self.maxNorm = maxNorm
   self.normType = normType

   -- Create the scratch buffers and default flag up front (the same values
   -- backCompatibility would create lazily) so methods such as `renorm` work
   -- even before the first updateOutput. Assigned inline rather than via
   -- self:backCompatibility() so a subclass override is not invoked before
   -- the subclass constructor has run.
   self.shouldScaleGradByFreq = false
   self._count = torch.IntTensor()
   self._input = torch.LongTensor()

   self:reset()
end
--- Fill in fields that a module created by an older version may lack.
-- Guarantees the scratch buffers `_count` (IntTensor) and `_input`
-- (LongTensor) exist, and turns a missing `shouldScaleGradByFreq` into false.
function LookupTable:backCompatibility()
   if not self._count then
      self._count = torch.IntTensor()
   end
   if not self._input then
      self._input = torch.LongTensor()
   end
   self.shouldScaleGradByFreq = self.shouldScaleGradByFreq or false
end
--- Discard the `gradWeight` buffer.
-- accGradParameters indexes `self.gradWeight`, so it cannot be called
-- directly while the buffer is nil.
-- NOTE(review): presumably intended for modules trained only through
-- accUpdateGradParameters — confirm.
-- @return self, for chaining
function LookupTable:accUpdateOnly()
   self.gradWeight = nil
   return self
end
--- Set the padding index that accGradParameters forwards to the THNN kernel.
-- @param paddingValue new padding index
-- @return self, for chaining
function LookupTable:setPadding(paddingValue)
   self.paddingValue = paddingValue
   return self
end
--- Set (or, with nil, disable) the max-norm used by `renorm` during forward.
-- @param maxNorm new maximum norm, or nil
-- @return self, for chaining
function LookupTable:setMaxNorm(maxNorm)
   self.maxNorm = maxNorm
   return self
end
--- Set the p of the norm used by `renorm` (nil means 2).
-- @param normType new norm type, or nil
-- @return self, for chaining
function LookupTable:setNormType(normType)
   self.normType = normType
   return self
end
--- Turn on gradient scaling by index frequency.
-- The flag is passed to the THNN accGradParameters kernel.
-- @return self, for chaining
function LookupTable:scaleGradByFreq()
   self.shouldScaleGradByFreq = true
   return self
end
--- Re-initialise `weight` in place with normal samples of mean 0.
-- @param stdv standard deviation of the samples (default 1)
function LookupTable:reset(stdv)
   self.weight:normal(0, stdv or 1)
end
--- Return `input` as a contiguous tensor of the same type as `_input`.
-- When `input` is non-contiguous or of another type, it is copied into the
-- `_input` scratch buffer and that buffer is returned. `copiedInput` records
-- which case happened, so accGradParameters can reuse the copy.
-- @param input index tensor
-- @return `input` itself, or the `_input` buffer holding a copy of it
function LookupTable:makeInputContiguous(input)
   local mustCopy = not input:isContiguous()
      or torch.type(input) ~= torch.type(self._input)
   self.copiedInput = mustCopy
   if not mustCopy then
      return input
   end
   -- resize and copy both return the receiver, i.e. self._input
   return self._input:resize(input:size()):copy(input)
end
--- Forward pass: gather rows of `weight` selected by `input`.
-- A 1-D input of n indices yields an n x nOutput output. A 2-D input of
-- size b x s yields a b x s x nOutput output. If `maxNorm` is set, the
-- selected rows are renormalised first.
-- @param input 1-D or 2-D tensor of row indices
-- @return self.output
function LookupTable:updateOutput(input)
   self:backCompatibility()
   self:renorm(input)
   local indices = self:makeInputContiguous(input)

   local nDim = indices:dim()
   if nDim ~= 1 and nDim ~= 2 then
      error("input must be a vector or matrix")
   end

   if nDim == 1 then
      self.output:index(self.weight, 1, indices)
   else
      -- gather on the flattened indices, then restore the batch layout
      local rows, cols = indices:size(1), indices:size(2)
      self.output:index(self.weight, 1, indices:view(-1))
      self.output = self.output:view(rows, cols, self.weight:size(2))
   end
   return self.output
end
--- Return a gradient buffer shaped and typed like `input`.
-- Indices have no meaningful gradient. The buffer is zero-filled whenever it
-- is (re)sized, and nothing in this class writes to it otherwise.
-- The input may be of any tensor type (forward converts it anyway), so the
-- buffer is reallocated whenever the caller's input type changes.
function LookupTable:updateGradInput(input, gradOutput)
   local typeMatches = torch.type(self.gradInput) == torch.type(input)
   if not typeMatches then
      self.gradInput = input.new()
   end

   local grad = self.gradInput
   if grad:isSameSizeAs(input) then
      return grad
   end
   grad:resizeAs(input):zero()
   return grad
end
--- Accumulate parameter gradients for the rows selected by `input`.
-- Delegates to THNN's LookupTable_accGradParameters, with `gradWeight` as
-- the accumulator. If the last forward pass had to copy the input (see
-- makeInputContiguous), that contiguous copy is used instead of `input`.
-- @param input 1-D or 2-D tensor of row indices
-- @param gradOutput gradient w.r.t. the module output
-- @param scale multiplier passed to the kernel (default 1)
function LookupTable:accGradParameters(input, gradOutput, scale)
   self:backCompatibility()

   local indices = input
   if self.copiedInput then
      indices = self._input
   end

   local nDim = indices:dim()
   if nDim == 2 then
      indices = indices:view(-1)
   elseif nDim ~= 1 then
      error("input must be a vector or matrix")
   end

   local kernel = self.gradWeight.THNN.LookupTable_accGradParameters
   kernel(
      indices:cdata(),
      gradOutput:cdata(),
      self.gradWeight:cdata(),
      self._count:cdata(),
      -- _sorted/_indices exist only after a CUDA type conversion
      THNN.optionalTensor(self._sorted),
      THNN.optionalTensor(self._indices),
      self.shouldScaleGradByFreq or false,
      self.paddingValue or 0,
      scale or 1
   )
end
--- Renormalise, in place, the rows of `weight` referenced by `input`.
-- Does nothing unless `maxNorm` is set. The THNN kernel modifies both
-- `weight` and the index tensor it is given. The indices are therefore first
-- copied into the contiguous `_input` scratch buffer, leaving the caller's
-- tensor untouched.
-- @param input 1-D or 2-D tensor of row indices
function LookupTable:renorm(input)
   local maxNorm = self.maxNorm
   if maxNorm then
      local scratch = self._input
      scratch:resize(input:size()):copy(input)

      local nDim = scratch:dim()
      if nDim ~= 1 and nDim ~= 2 then
         error("input must be a vector or matrix")
      end

      local flat = scratch
      if nDim == 2 then
         flat = scratch:view(-1)
      end

      self.weight.THNN.LookupTable_renorm(
         flat:cdata(),
         self.weight:cdata(),
         maxNorm,
         self.normType or 2
      )
   end
end
--- Convert the module to tensor type `type` (nn.Module:type semantics).
-- Besides the generic conversion, this reallocates the index scratch buffers.
-- For CUDA types, `_sorted`, `_indices`, `_count` and `_input` become
-- CudaLongTensors (CudaTensor when cutorch lacks CudaLongTensor). Otherwise
-- `_count` is an IntTensor and `_input` a LongTensor.
-- @param type target tensor type name; when nil, the call is forwarded to
--   nn.Module:type unchanged (its getter form) and the buffers are untouched
-- @param tensorCache optional conversion cache, forwarded to the parent
-- @return self after a conversion; otherwise whatever nn.Module:type returns
function LookupTable:type(type, tensorCache)
   if not type then
      -- Nothing is converted in this form, so the buffers must be left alone.
      -- Falling through would install CPU buffers on a CUDA module and break
      -- its next forward/backward pass.
      return parent.type(self, type, tensorCache)
   end

   parent.type(self, type, tensorCache)

   if type:find('torch%.Cuda.*Tensor') then
      -- Index buffers for the CUDA kernels; older cutorch builds have no
      -- CudaLongTensor, in which case CudaTensor is used instead.
      local function newCudaIndexTensor()
         if torch.CudaLongTensor then
            return torch.CudaLongTensor.new()
         end
         return torch.CudaTensor.new()
      end
      self._sorted = newCudaIndexTensor()
      self._indices = newCudaIndexTensor()
      self._count = newCudaIndexTensor()
      self._input = newCudaIndexTensor()
   else
      -- _count and _input must keep their CPU index types; the generic
      -- conversion above would have turned them into `type`.
      self._count = torch.IntTensor()
      self._input = torch.LongTensor()
   end

   return self
end
--- Release the temporary index buffers, then clear the generic module state.
-- Besides `_count` and `_input`, this also releases the `_sorted` and
-- `_indices` buffers that `type` allocates for CUDA. On CPU those two fields
-- are nil, just as `_count`/`_input` are on a module that never ran forward.
-- NOTE(review): relies on nn.utils.clear skipping nil fields — confirm.
-- @return the result of nn.Module:clearState
function LookupTable:clearState()
   nn.utils.clear(self, '_count', '_input', '_sorted', '_indices')
   return parent.clearState(self)
end
-- We do not need to accumulate parameters when sharing, so the shared variant
-- simply aliases accUpdateGradParameters. That method is not defined in this
-- file. The lookup happens once, when this line runs.
-- NOTE(review): it presumably resolves to the inherited nn.Module
-- implementation through the class metatable — confirm.
LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters