-
Notifications
You must be signed in to change notification settings - Fork 2
/
Sampler.lua
31 lines (24 loc) · 907 Bytes
/
Sampler.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
-- Based on JoinTable module
--Written by Shanu Kumar
--Copyright (c) 2019, Shanu Kumar [See LICENSE file for details]
require 'nn'
--- nn.Sampler: an nn.Module that scales freshly drawn standard-normal noise
-- by its input (see updateOutput / updateGradInput).
local Sampler, parent = torch.class('nn.Sampler', 'nn.Module')

--- Construct a sampler.
-- Runs the nn.Module constructor first, then records how many noise
-- columns are drawn per input row on each forward pass.
-- @param dim number of columns in the sampled output
function Sampler:__init(dim)
   parent.__init(self)
   self.dim = dim
end
--- Forward pass: output = eps .* expand(input), where eps ~ N(0, 1).
-- A fresh noise matrix of size (input:size(1), self.dim) is drawn on every
-- call and stored in self.eps so updateGradInput can reuse it.
-- NOTE(review): input is expanded to (input:size(1), self.dim), so it
-- presumably has shape (batch, 1) — confirm against callers.
-- @param input tensor whose first dimension is the batch size
-- @return self.output, a (input:size(1), self.dim) tensor
function Sampler:updateOutput(input)
   local batchSize = input:size(1)
   -- Draw with torch.randn (same generator as before), then cast to the
   -- input's tensor type so the element-wise product below type-checks.
   self.eps = torch.randn(batchSize, self.dim):type(input:type())
   -- Fall back to an empty tensor of the input's type. The previous fallback
   -- `self.output.new()` indexed the very field it had just found to be nil.
   self.output = self.output or input.new()
   self.output:resizeAs(self.eps):copy(self.eps)
   self.output:cmul(torch.expand(input, batchSize, self.dim))
   return self.output
end
--- Backward pass for output = eps .* expand(input).
-- Because output[i][j] = eps[i][j] * input[i][1], the gradient with respect
-- to the input is
--   gradInput[i][1] = sum_j eps[i][j] * gradOutput[i][j].
-- Reuses the noise self.eps stored by the preceding updateOutput call.
-- @param input the tensor given to updateOutput (presumably (batch, 1))
-- @param gradOutput gradient w.r.t. output, size (input:size(1), self.dim)
-- @return self.gradInput, resized to input's shape
function Sampler:updateGradInput(input, gradOutput)
   self.gradInput = self.gradInput or input.new()
   -- The previous version returned 0.5 * input .* sum(eps .* gradOutput).
   -- The extra 0.5 * input factor is not part of the derivative of
   -- updateOutput (which is linear in input), so it has been removed.
   local weighted = torch.cmul(self.eps, gradOutput)
   self.gradInput:resizeAs(input):copy(torch.sum(weighted, 2))
   return self.gradInput
end