mnist_mp.m
% load the training images
imds = imageDatastore('../dataset/mnist/images');
% class names and the pixel values that encode them in the mask images
classNames = ["zero","one","two","three","four","five","six","seven","eight","nine","ten"];
pixelLabelID = [0,1,2,3,4,5,6,7,8,9,10];
pxds = pixelLabelDatastore('../dataset/mnist/masks',classNames,pixelLabelID);
% count the number of pixels belonging to each class
tbl = countEachLabel(pxds);
% counter the class imbalance with median frequency balancing:
% each class is weighted by median(freq)/freq, so rare classes get larger weights
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq;
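% Optional sanity check (not in the original script): list each class next to
% its frequency and computed weight to confirm that rare classes get larger weights.
disp(table(tbl.Name,imageFreq,classWeights, ...
    'VariableNames',{'Class','Frequency','Weight'}));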
% visualize the per-class pixel frequencies
bar(1:numel(classNames),imageFreq);
xticks(1:numel(classNames));
xticklabels(tbl.Name);
xtickangle(45);
ylabel('Frequency');
% Use data augmentation during training to provide more varied examples,
% which helps improve the accuracy of the network. This augmenter follows
% the MATLAB semantic segmentation example online.
augmenter = imageDataAugmenter('RandXReflection',true, ...
    'RandXTranslation',[-5,5], ...
    'RandYTranslation',[-5,5]);
% combine the images and pixel labels into a training datastore
plds = pixelLabelImageDatastore(imds,pxds,'DataAugmentation',augmenter);
% shuffle the training datastore
plds = shuffle(plds);
% load the test set (images and their ground-truth masks)
test_imds = imageDatastore('../dataset/mnist/test_images');
test_pxds = pixelLabelDatastore('../dataset/mnist/test_masks',classNames,pixelLabelID);
% Define the segmentation network
numClasses = 11;   % matches the 11 entries in classNames
numFilters = 128;  % base filter count; deeper blocks use multiples of this
imageSize = [28,28,1];
layers = [
    imageInputLayer(imageSize,'Name','input')
    % block 1
    convolution2dLayer(3,numFilters,'Padding','same','Name','conv1_1')
    convolution2dLayer(3,numFilters,'Padding','same','Name','conv1_2')
    reluLayer('Name','relu1_2')
    batchNormalizationLayer('Name','BN1')
    maxPooling2dLayer(2,'Stride',2)
    % block 2
    convolution2dLayer(3,2*numFilters,'Padding','same')
    convolution2dLayer(3,2*numFilters,'Padding','same')
    reluLayer()
    batchNormalizationLayer('Name','BN2')
    maxPooling2dLayer(2,'Stride',2)
    % block 3
    convolution2dLayer(3,4*numFilters,'Padding','same')
    convolution2dLayer(3,4*numFilters,'Padding','same')
    reluLayer()
    batchNormalizationLayer('Name','BN3')
    % decoder: upsample back to the input resolution
    transposedConv2dLayer(3,4*numFilters,'Stride',2,'Cropping','same')
    transposedConv2dLayer(5,8*numFilters,'Stride',2,'Cropping','same')
    batchNormalizationLayer('Name','BN5')
    % per-pixel classification head
    convolution2dLayer(1,numClasses)
    softmaxLayer()
    pixelClassificationLayer('Name','labels','Classes',tbl.Name,'ClassWeights',classWeights)
    ];
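% Note on spatial sizes (follows from the layer parameters above): the two
% 2x2/stride-2 max-pooling layers reduce 28x28 -> 14x14 -> 7x7, and the two
% stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28, so the
% per-pixel predictions match the 28x28x1 input.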
analyzeNetwork(layers)
% define the training options (SGDM optimizer with a piecewise learning-rate schedule)
opts = trainingOptions('sgdm', ...
    'InitialLearnRate',2e-3, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',3, ...
    'LearnRateDropFactor',0.5, ...
    'MaxEpochs',9, ...
    'Momentum',0.9, ...
    'ExecutionEnvironment','gpu', ...
    'MiniBatchSize',32, ...
    'Plots','training-progress', ...
    'ValidationPatience',10); % only takes effect if 'ValidationData' is also supplied
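% Note (added): 'ExecutionEnvironment','gpu' errors out on machines without a
% supported GPU; 'auto' would fall back to the CPU instead.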
% train the network
net = trainNetwork(plds,layers,opts);
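% Optional quick qualitative check (not in the original script): segment one
% test image with the trained network and overlay the predicted labels on it.
I = readimage(test_imds,1);
C = semanticseg(I,net);
figure;
imshow(labeloverlay(I,C));
title('Predicted segmentation for one test image');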
% segment the test images with the trained network
pxdsPred = semanticseg(test_imds,net,'MiniBatchSize',64,'WriteLocation','../dataset/mnist_preds');
% score the predictions against the ground-truth test masks
metrics = evaluateSemanticSegmentation(pxdsPred,test_pxds);
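% Optional (not in the original script): inspect the per-class accuracy and IoU
% that evaluateSemanticSegmentation reports in metrics.ClassMetrics.
disp(metrics.ClassMetrics);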
% save the trained network, tagging the filename with the weighted IoU on the test set
filename = strcat('models/mnist/mnist_mp_',sprintf('%.2f',metrics.DataSetMetrics.WeightedIoU),'_iou.mat');
save(filename,'net');
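% Optional suggestion (not in the original script): also store the evaluation
% metrics alongside the network so different runs can be compared later, e.g.
% save(filename,'net','metrics');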