-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsoftmaxSolver_Demo.m
76 lines (60 loc) · 1.81 KB
/
softmaxSolver_Demo.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
function softmaxSolver_Demo_tc()
% Exercise -- two-class softmax solver
%
% SOFTMAXSOLVER_DEMO_TC  Generates a 2-D two-class training set and an
% independent test set with tcdataGenerator. It then:
%   1. trains a softmax model with softmaxSGD (C = 1) and plots its boundary
%      in the left subplot of figure 1;
%   2. trains a softmax model with softmaxLBFGS (C = 100) and plots its
%      boundary in the right subplot;
%   3. prints the train and test accuracy of the L-BFGS model.
% Labels of -1 are remapped to 2, so classes are indexed 1/2.
% NOTE(review): presumably the softmax routines expect labels 1..K --
% confirm against softmaxSGD/softmaxLBFGS.
%
% No inputs and no outputs. Closes all figures, clears the command window,
% and displays theta/cost for each solver.

% `clear all` was dropped: inside a function there is nothing to clear, and
% it needlessly removes all functions, globals and MEX links from memory.
close all; clc

%% generate data
nsamples = 200;
% training data
[x, y] = tcdataGenerator(nsamples, 0.5, 'normal');
y(y == -1) = 2;
% testing data (a second, independent draw with the same settings)
[xt, yt] = tcdataGenerator(nsamples, 0.5, 'normal');
yt(yt == -1) = 2;

%% Solver 1: stochastic gradient descent
% Candidate solvers: FastDescent ConjugateGradient Newton FixedNewton DFP BFGS SGD
% All SGD settings live in the struct that is actually passed to the solver.
% Previously they were written to a separate struct `options`, which was
% never used, so they were silently ignored.
option.C = 1;
option.debug = 0;
option.epochs = 3;
option.minibatch = 50;
option.alpha = 1e-1;
option.momentum = .95;
[theta, cost] = softmaxSGD(x, y, option)

%% Visualize Results
figure(1)
subplot(121)
plotDecisionBoundary(x, y, theta);

%% Solver 2: L-BFGS with stronger C
% L-BFGS receives only the C/debug settings, exactly as before.
option = struct('C', 100, 'debug', 0);
[theta, cost] = softmaxLBFGS(x, y, option)

%% Visualize Results
subplot(122)
plotDecisionBoundary(x, y, theta);

%% predict (L-BFGS model)
acc = predictAccuracy(x, y, theta);
disp(['train acc: ', num2str(acc)]);
% Score the held-out test inputs. Previously the training inputs were
% re-scored here and compared against the test labels.
acc = predictAccuracy(xt, yt, theta);
disp(['test acc: ', num2str(acc)]);


function plotDecisionBoundary(x, y, theta)
% PLOTDECISIONBOUNDARY  Scatter the class-1 rows of x (blue '+') and the
% class-2 rows (green 'x') into the current axes, then overlay in red the
% line theta(1) + theta(2)*x1 + theta(3)*x2 = 0. The line spans the range of
% the first feature, padded by 1 on each side.
% NOTE(review): theta is read by linear index 1..3. If the solver returns
% one weight column per class, this draws class 1's line rather than the
% two-class softmax boundary -- confirm the shape of theta.
x1min = min(x(:, 1)) - 1;
x1max = max(x(:, 1)) + 1;
pos = x(y == 1, :);
neg = x(y == 2, :);
scatter(pos(:, 1), pos(:, 2), 'b+', 'SizeData', 200, 'LineWidth', 2);
hold on
scatter(neg(:, 1), neg(:, 2), 'gx', 'SizeData', 200, 'LineWidth', 2);
axis tight
x1 = x1min:0.1:x1max;
plot(x1, (-theta(1) - x1*theta(2))/theta(3), 'r-', 'LineWidth', 2);
hold off


function acc = predictAccuracy(x, y, theta)
% PREDICTACCURACY  Prepend a column of ones to x, score it with
% softmaxFunc, and return the fraction of rows whose highest-scoring class
% (column index along dim 2) equals the label in y.
xx = [ones(size(x, 1), 1), x];
h = softmaxFunc(xx, theta);
[~, p] = max(h, [], 2);
acc = mean(p == y);