forked from houxianxu/DFC-VAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
GaussianCriterion.lua
28 lines (19 loc) · 923 Bytes
/
GaussianCriterion.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
require 'nn'
local GaussianCriterion, parent = torch.class('nn.GaussianCriterion', 'nn.Criterion')

--- Negative log-likelihood of `target` under a Gaussian parameterised by
-- mean and log-variance, summed over all elements:
--   sum( 0.5*log(2*pi) + 0.5*log(sigma^2) + 0.5*(x - mu)^2 / sigma^2 )
-- @param input  table: input[1] = mu, input[2] = log(sigma^2)
-- @param target tensor x (assumed to match the size of input[1] -- TODO confirm)
-- @return number the summed NLL (also stored in self.output)
function GaussianCriterion:updateOutput(input, target)
   local mu, logVar = input[1], input[2]
   -- 0.5 * log(sigma^2) + 0.5 * log(2*pi)
   local Gelement = torch.mul(logVar, 0.5):add(0.5 * math.log(2 * math.pi))
   -- + 0.5 * (x - mu)^2 / sigma^2
   Gelement:add(torch.add(target, -1, mu):pow(2):cdiv(torch.exp(logVar)):mul(0.5))
   self.output = torch.sum(Gelement)
   return self.output
end

--- Gradient of the summed NLL computed by updateOutput.
-- @param input  table: input[1] = mu, input[2] = log(sigma^2)
-- @param target tensor x
-- @return table {dNLL/dmu, dNLL/dlog(sigma^2)} (also stored in self.gradInput)
function GaussianCriterion:updateGradInput(input, target)
   local mu, logVar = input[1], input[2]
   -- 1 / sigma^2 = exp(-log(sigma^2)); shared by both gradients
   local invVar = torch.exp(-logVar)
   -- x - mu, a fresh tensor (torch.add with a scalar multiplier allocates)
   local diff = torch.add(target, -1, mu)
   self.gradInput = {}
   -- dNLL/dmu = -(x - mu) / sigma^2
   -- (computed before diff is squared in place below)
   self.gradInput[1] = torch.cmul(invVar, diff):mul(-1)
   -- dNLL/dlog(sigma^2) = 0.5 - 0.5 * (x - mu)^2 / sigma^2
   -- the 0.5 factor matches the 0.5 on the quadratic term in updateOutput
   self.gradInput[2] = torch.cmul(invVar, diff:pow(2)):mul(-0.5):add(0.5)
   return self.gradInput
end