-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_dl_clas.m
134 lines (107 loc) · 3.18 KB
/
model_dl_clas.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
%% Import the myoelectric recordings and their motion labels
% Seven motion classes are labelled: hand open, hand close, wrist flexion,
% wrist extension, supination, pronation and rest.
% Files ending in "d.mat" hold the signals; files ending in "i.mat" hold the
% label variables "motion" and "data_indx".
fs = 3000;                          % sampling rate passed to SampleRate (Hz)
datasetFolder = "MyoelectricData";
if ~exist(datasetFolder,"dir")
    % The dataset is extracted from a zip archive on first use. The archive
    % path must be provided in the workspace variable "localfile"; fail with
    % a clear message instead of an "Unrecognized function or variable" error.
    if ~exist("localfile","var")
        error("model_dl_clas:missingArchive", ...
            "Folder ""%s"" not found and no archive path in variable ""localfile"" to unzip it from.", ...
            datasetFolder);
    end
    unzip(localfile,datasetFolder)
end
sds1 = signalDatastore(datasetFolder,IncludeSubfolders=true,SampleRate=fs);
p = endsWith(sds1.Files,"d.mat");   % signal (data) files
sdssig = subset(sds1,p);
sds2 = signalDatastore(datasetFolder,SignalVariableNames=["motion";"data_indx"],IncludeSubfolders=true);
p = endsWith(sds2.Files,"i.mat");   % label files
sdslbl = subset(sds2,p);
%% Plot every channel of the first recording in a 4x2 grid
signal = preview(sdssig);
numChannels = 8;
for ch = 1:numChannels
    ax(ch) = subplot(4,2,ch);
    plot(signal(:,ch))
    title("Channel"+ch)
end
% Share y-limits across all channel axes so amplitudes compare directly
linkaxes(ax,"y")
%% Build an ROI label table for every recording
% Each read returns {motion, data_indx} (the SignalVariableNames order).
% data_indx entries are used as segment start samples and motion entries as
% class codes 1..7; each table row holds [start end] plus the motion name.
motionNames = ["HandOpen" "HandClose" "WristFlexion" "WristExtension" "Supination" "Pronation" "Rest"];
reset(sdslbl);                      % rewind so re-running this section rebuilds every table
lbls = cell(1,numel(sdslbl.Files));
i = 1;
while hasdata(sdslbl)
    label = read(sdslbl);
    % First and last markers are dropped; presumably they are not motion
    % segments -- TODO confirm against the dataset description.
    idx_start = label{2}(2:end-1)';
    % A segment ends one sample before the next starts; the final segment is
    % assumed to last 3 s (3*fs samples).
    idx_end = [idx_start(2:end)-1;idx_start(end)+(3*fs)];
    val = categorical(label{1}(2:end-1)',1:numel(motionNames),motionNames);
    ROI = [idx_start idx_end];
    % Label values and ROI rows can disagree in count; keep only the common
    % leading part so both table columns have the same height.
    n = min(numel(val),size(ROI,1));
    ROI = ROI(1:n,:);
    val = val(1:n);
    lbltable = table(ROI,val);
    % Wrapped in a cell for signalDatastore(lbls) below, which presumably
    % makes each read yield the table itself -- verify if this changes.
    lbls{i} = {lbltable};
    i = i+1;
end
lbls = lbls(1:i-1);                 % drop any unused preallocated slots
%% Display the ROI label table of the first recording
lblDS = signalDatastore(lbls);      % one ROI table per recording, in read order
lblstable = preview(lblDS);
lblstable{1}
%% Pair each signal file with its ROI label table
DS = combine(sdssig, lblDS);
combinedData = preview(DS);         % {signal, ROI table} of the first recording
%% Overlay the labelled motion regions on channel 1
figure
msk = signalMask(combinedData{2});
plotsigroi(msk, combinedData{1}(:, 1))
%% preprocessing
% preprocessing.m must be on the MATLAB path (e.g. in the current folder).
% addpath expects a folder, not a file name, so verify the function is
% reachable instead of calling addpath('preprocessing.m').
if ~exist("preprocessing","file")
    error("model_dl_clas:missingPreprocessing", ...
        "preprocessing.m was not found on the MATLAB path.");
end
tDS = transform(DS,@preprocessing);
transformedData = preview(tDS)      % display the first transformed observation
%% Split observations into training and test sets by whole groups
% Observations are treated as numGroups consecutive blocks of obsPerGroup
% (presumably one block per subject/session -- TODO confirm). Whole blocks
% are assigned to either training or test, so no block is split.
rng default
numGroups = 30;
obsPerGroup = 24;
%[trainIdx,~,testIdx] = dividerand(numGroups,0.75,0,0.25);
[trainIdx,~,testIdx] = dividerand(numGroups,0.8,0,0.2);
% Expand group numbers g into observation indices (g-1)*obsPerGroup+1 : g*obsPerGroup,
% concatenated in the order the groups appear, returned as a column vector.
groupToObs = @(g) reshape((1:obsPerGroup)' + (g(:).' - 1)*obsPerGroup, [], 1);
trainIdx_all = groupToObs(trainIdx);
trainDS = subset(tDS,trainIdx_all);
testIdx_all = groupToObs(testIdx);
testDS = subset(tDS,testIdx_all);
%% LSTM network
% Sequence-to-sequence classifier: 8 input channels -> LSTM(80) -> one class
% score per time step. The fully connected output size must equal the number
% of label categories (it was hard-coded to 4 while the labels define 7), so
% derive it from the preprocessed labels.
% Assumes preprocessing returns {signal, labels} with categorical labels, as
% trainNetwork/confusionchart below require -- confirm if preprocessing changes.
numClasses = numel(categories(transformedData{2}));
layers = [ ...
    sequenceInputLayer(8), ...
    lstmLayer(80,OutputMode="sequence"), ...
    fullyConnectedLayer(numClasses), ...
    softmaxLayer, ...
    classificationLayer];
% adam and shuffle
options = trainingOptions("adam", ...
    MaxEpochs=100, ...
    MiniBatchSize=32, ...
    Plots="training-progress",...
    InitialLearnRate=0.001,...
    Verbose=0,...
    Shuffle="every-epoch",...
    GradientThreshold=1e5,...
    DispatchInBackground=true);
%% Train the network on the training split
% readall stacks every transformed observation into a two-column cell array:
% column 1 holds the predictor sequences, column 2 the label sequences.
traindata = readall(trainDS,UseParallel=true);
rawNet = trainNetwork(traindata(:,1),traindata(:,2),layers,options);
%% Evaluate on the held-out test split
testdata = readall(testDS);
predTest = classify(rawNet,testdata(:,1),MiniBatchSize=32);
% Concatenate the per-sequence labels and compare true vs. predicted
trueLabels = [testdata{:,2}];
predLabels = [predTest{:}];
confusionchart(trueLabels,predLabels,Normalization="column-normalized")