/ onevsall.q
\c 20 100
\l funq.q
\l mnist.q
/ digit recognition
-1"referencing mnist data from global namespace";
`X`Xt`Y`y`yt set' mnist`X`Xt`Y`y`yt;
-1"shrinking training set";
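/ keep the first 1000 training examples (examples run along the columns)
X:1000#'X;Y:1000#'Y;y:1000#y;
/ scale pixel intensities from 0-255 down to 0-1
X%:255f;Xt%:255f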
-1"define a plot function that includes the empty space character";
/ plt cuts a flat pixel vector into rows of 28 and renders it as text
plt:value .ut.plot[28;14;.ut.c10;avg] .ut.hmap flip 28 cut
-1"visualize the data";
/ pick 4 distinct random images and print them side by side
-1 (,'/) plt each X@\:/: -4?count X 0;
lbls:"i"$til 10
rf:.ml.l2[1] / regularization function
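theta:(1+count X)#0f / initial theta: one coefficient per pixel plus an intercept
/ f takes a 0/1 target matrix and returns the theta found by fmincg (limit of 5)
f:first .fmincg.fmincg[5;;theta] .ml.logcostgrad[rf;;X]@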
-1"to run one-vs-all",$[count rf;" with regularization";""];
-1"we perform multiple runs of logistic regression (one for each digit)";
-1"this trains one set of parameters for each number";
-1 .ut.box["**"] "for performance, we peach across digits";
THETA:.ml.fova[f;Y;lbls]
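/ a minimal sketch of the per-label targets that one-vs-all training relies on:
/ each label gets its own 0/1 vector marking where the labels equal it.
/ ytoy, ltoy and btoy are hypothetical names with toy values, used only for
/ illustration; how .ml.fova builds its targets internally is an assumption.
ytoy:0 1 2 1                / toy labels for four examples
ltoy:0 1 2                  / toy set of distinct labels
btoy:ytoy=/:ltoy            / one boolean row per label: 1000b, 0101b, 0010b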
-1"checking accuracy of parameters";
/ accuracy: fraction of test digits whose predicted label p matches yt
avg yt=p:lbls .ml.imax .ml.plog[Xt] THETA
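/ a toy sketch of the prediction and accuracy steps above, using made-up
/ probabilities. the layout (one row per label, one column per example) and
/ the names ptoy, qtoy and atoy are assumptions for illustration only;
/ {x?max x} stands in for an index-of-maximum helper.
ptoy:(0.1 0.7;0.8 0.2;0.1 0.1)    / 3 labels by 2 examples
qtoy:{x?max x} each flip ptoy     / most probable label per example: 1 0
atoy:avg qtoy=1 2                 / true labels 1 2, one of two matches: 0.5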
-1"view a few confused characters";
w:where not yt=p
do[2;-1 plt Xt[;i:rand w];show ([]p;yt) i]
-1"view the confusion matrix";
show .ut.totals[`TOTAL] .ml.cm[yt;p]
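/ a toy sketch of the tally behind a confusion matrix: count how often each
/ (actual;predicted) pair occurs. ctoy, dtoy and ttoy are hypothetical names;
/ this is not necessarily how .ml.cm is implemented.
ctoy:0 1 1 2                           / toy actual labels
dtoy:0 1 2 2                           / toy predicted labels
ttoy:count each group flip (ctoy;dtoy) / each pair occurs once in this toy data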