-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathempirical.m
24 lines (21 loc) · 850 Bytes
/
empirical.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
% Empirical (count-based) logits for the visible units of one class.
%
% Selects every sample with partition_all == 1 and label_all == label_counting,
% sums the visible-unit vectors returned by get_data_from_index for those
% samples, and converts the smoothed per-unit mean p into a logit:
%
%   p          = (1 + sum_of_visible_units) / (N + 2)
%   w_counting = log(p ./ (1 - p))
%
% where N is the number of selected samples. The +1 / +2 pseudocounts
% (add-one smoothing) keep p strictly inside (0, 1) when every unit value lies
% in [0, 1], so the logit stays finite and real even for units that are always
% or never active. With no matching samples p is 0.5 and every logit is 0.
%
% NOTE(review): assumes get_data_from_index returns a visible x 1 column whose
% entries lie in [0, 1] (e.g. binary pixels) -- confirm. partition_all,
% label_all, data_all, visible, const_h, const_w and channels must already
% exist in the workspace.
label_counting = 23;
indices = find(partition_all == 1 & label_all == label_counting);
num_samples = numel(indices);
if num_samples == 0
    warning('empirical:noSamples', ...
        'No samples with partition 1 and label %d; logits default to 0.', label_counting);
end
counting = ones(visible, 1);   % add-one pseudocount for every visible unit
% Reshape to a row so the loop visits one index per iteration (and none when empty).
for eval_index = reshape(indices, 1, [])   % for each data point
    visible_units = get_data_from_index(data_all, eval_index, const_h, const_w, channels);
    counting = counting + visible_units;
end
% w = gather(matrix(visible+hidden+1, 1:visible));
% Divide by N + 2, not N: after the add-one initialisation, dividing by N lets
% p reach 1 (log -> Inf) or exceed it (log of a negative -> complex).
counting = counting / (num_samples + 2);
% .' is the plain (non-conjugating) transpose; turns the column into a row.
w_counting = log(counting ./ (1 - counting)).';
% Store the empirical logits in `matrix` and save it to disk.
% Row (visible + hidden + 1) and column (visible + hidden + 1) both receive
% w_counting over the first `visible` positions, so those two slices mirror
% each other. NOTE(review): that index is presumably a bias unit -- confirm.
% Afterwards EVERY entry of the matrix (not only the bias slices) is floored
% at -10, cast to single precision, moved to the GPU via gpuArray (Parallel
% Computing Toolbox), and written to matrix_<label>_counting_test.mat.
bias_index = visible + hidden + 1;
visible_range = 1:visible;
matrix(bias_index, visible_range) = w_counting;
matrix(visible_range, bias_index) = w_counting(:);
floor_value = -10;
matrix = gpuArray(single(max(matrix, floor_value)));
out_file = ['matrix_', num2str(label_counting), '_counting_test.mat'];
save(out_file, 'matrix');
% Debug visualisation: show `weights` and the normalised counts `counting`
% as grey-scale images with 7 rows each. The column count is inferred from the
% number of elements, so any vector whose length is a multiple of 7 works
% (8624 elements still gives 7 x 1232). imshow(..., []) scales the display
% range to the data's min/max. imshow requires the Image Processing Toolbox.
% NOTE(review): `weights` is not defined in this script; it is presumably left
% in the workspace by an earlier script -- confirm.
display_rows = 7;
figure;
imshow(reshape(weights, display_rows, []), []);
figure;
imshow(reshape(counting, display_rows, []), []);