calcGradient.m
function [ nn ] = calcGradient( nn, Y, C, b )
% Calculate the derivative of the loss L w.r.t. every weight w in every
% layer, and store the results in nn.deltaW{j} for layer j.
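%
% Minimal usage sketch (an assumption: mynnff, referenced in the
% commented-out checks below, is the forward pass that populates
% the activations nn.a{1..nn.n}):
%   nn = mynnff(nn, X);             % forward pass
%   nn = calcGradient(nn, Y, C, b); % fills nn.deltaW{1..nn.n-1}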
Xred = nn.a{nn.n - 1};
J = nn.n; % number of layers
[m, n] = size(Xred);
%% Analytic calculation of the m-by-n matrix dL/dX
dLdX = dLogPartialLdX(Xred, Y, C, b);
%% Numerical check of dLdX (central differences; kept for debugging)
% e = 1e-4;
% dLdXapprox = zeros(m, n);
% for i = 1:m
% for j = 1:n
% XredMinus = Xred;
% XredMinus(i, j) = Xred(i, j) - e;
% LMinus = LogPartialL(XredMinus, Y, C, b); % loss at the perturbed point
%
% XredPlus = Xred;
% XredPlus(i, j) = Xred(i, j) + e;
% LPlus = LogPartialL(XredPlus, Y, C, b);
% dLdXapprox(i, j) = (LPlus - LMinus) / (2 * e);
% end
% end
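%
% If the check above is enabled, the analytic and numerical results can
% be compared via a relative error (a sketch, not part of the original
% check; central differences with e = 1e-4 typically agree to ~1e-6):
%   relErr = norm(dLdX - dLdXapprox, 'fro') / ...
%            max(norm(dLdX, 'fro'), norm(dLdXapprox, 'fro'));
%   assert(relErr < 1e-5, 'dLdX check failed: relErr = %g', relErr);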
%
%% Last layer: dL/dW{J-1} comes directly from dLogPartialL
nn.deltaW{J - 1} = dLogPartialL(Xred, Y, C, b)';
%% Exact computation of dXX
dxx = cell(1,m);
for i = 1:m
dxx{i} = dXX(nn, i);
end
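% dxx{i} depends only on the sample index i, so it is computed once here
% and reused for every weight position (p, q) below, rather than being
% recomputed P*Q times per layer.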
%% Activation derivatives ('sigm' or 'tanh_opt') for each layer
nn.d_act = cell(size(nn.a));
for j = (J - 1):-1:1
switch nn.activation_function
case 'sigm'
nn.d_act{j} = nn.a{j} .* (1 - nn.a{j});
case 'tanh_opt'
nn.d_act{j} = 1.7159 * 2/3 * (1 - 1/(1.7159)^2 * nn.a{j}.^2);
end
end
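% Both cases express the derivative in terms of the stored activations
% alone: for 'sigm', d/dz sigma(z) = a .* (1 - a); the 'tanh_opt' line is
% the derivative of f(z) = 1.7159 * tanh(2*z/3) rewritten with a = f(z).
% A quick standalone check for the sigmoid case (a sketch, not called
% by this function):
%   z = randn(3); a = 1 ./ (1 + exp(-z));
%   dnum = (1./(1+exp(-(z+1e-6))) - 1./(1+exp(-(z-1e-6)))) / 2e-6;
%   max(abs(dnum(:) - a(:) .* (1 - a(:))))  % on the order of 1e-10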
for j = (J - 2):-1:1
% tic;
[P, Q] = size(nn.W{j});
nn.deltaW{j} = zeros(P, Q);
% nn.deltaWapprox{j} = zeros(P, Q);
% nn.dXdwapprox{j} = zeros([P, Q, size(Xred)]);
for p = P:-1:1
for q = Q:-1:1
%% Numerical calculation of dXdW (kept for debugging)
%
% nnPlus = nn;
% nnMinus = nn;
%
% nnPlus.W{j}(p , q) = nnPlus.W{j}(p , q) + e;
% nnMinus.W{j}(p , q) = nnMinus.W{j}(p , q) - e;
%
% nnPlus = mynnff(nnPlus, nnPlus.a{1}(:, 2:end) );
% nnMinus = mynnff(nnMinus, nnMinus.a{1}(:, 2:end) );
%
% XredMinus = nnMinus.a{nn.n - 1};
% XredPlus = nnPlus.a{nn.n - 1};
%
% nn.dXdwapprox{j}(p,q, :,:) = (XredPlus - XredMinus)/(2 * e);
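% Note on the check above: MATLAB structs are copied by value, so
% perturbing nnPlus.W / nnMinus.W leaves nn itself untouched. The
% (:, 2:end) indexing appears to drop a leading bias column from
% nn.a{1} before re-running the forward pass (an assumption about
% mynnff's input convention).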
%% Backpropagation
for i = 1:m
% For each weight w = W{j}(p, q): L is a sum over the samples X_i, so
% dL/dw = sum_i dot(dX_i/dw, dL/dX_i), accumulated below.
tmp = dXW(nn, dxx{i}, i, j, p, q);
% tmp2 = reshape(nn.dXdwapprox{j}(p, q, i, :), 1, n);
% diff = norm(tmp' - tmp2)
%nn.deltaWapprox{j}(p, q) = nn.deltaWapprox{j}(p, q) + reshape(nn.dXdwapprox{j}(p, q, i, :), 1, n) * ...
%dLdXapprox(i, :)';
nn.deltaW{j}(p, q) = nn.deltaW{j}(p, q) + tmp' * dLdX(i, :)';
end
end
end
% toc
end
end
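% Design note: the exact backpropagation above calls dXW once per
% (layer, p, q, sample), i.e. O(P * Q * m) evaluations per layer. A
% vectorized delta-based backward pass would avoid the per-weight loops;
% the explicit form here mirrors the element-wise numerical checks kept
% in the comments.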