-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_ranking.m
70 lines (62 loc) · 1.77 KB
/
train_ranking.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
%% Load and prep data
% loadData is defined outside this file. The code below relies on it to
% leave trainMatrix, trainLabel, testMatrix and testLabel in the workspace
% (one row per example; labels row-aligned with the matrices).
loadData

%% Ranking SVM
% Pairwise ranking SVM trained by sub-gradient descent, with a grid search
% over:
%   C    - weight of the pairwise hinge-loss term
%   n0   - initial learning rate (decayed per epoch, see g below)
%   iter - number of passes (epochs) over all training pairs
tryC = [.01];
tryn0 = [100];
tryIter = [5];
% Alternative grids used in earlier experiments:
% tryC = [0.01, 0.1, 1, 10, 100, 1000];
% tryn0 = [0.01, 0.5, 1.0, 10, 100];
% tryIter = [5,50,100,1000,5000,6000];
% tryC = [0.00001, 0.0001, 0.001, .01];
% tryn0 = [0.01, 0.5, 1.0, 10, 100];
% tryIter = [5,50,100,1000,5000,6000];

w = zeros(1,size(trainMatrix, 2));
numTrain = size(trainMatrix, 1);
numTest = size(testMatrix, 1);
% Number of ordered pairs (i,j) with i < j; spreads the regulariser
% gradient evenly over the per-pair updates of one epoch.
divisor = numTrain*(numTrain-1)/2;

% Best result and hyper-parameters
bestCorr = -1;
bestC = 0;
bestn0 = 0;
bestIter = 5;
bestW = w;      % weight vector of the best run (zeros until a run wins)

% IMPORTANT: Flip matrix upside down to get ranking order from greatest to
% least. The training loop treats row i as ranked above row j whenever
% i < j, so the rows must be in descending rank order here.
% NOTE(review): this presumes loadData returns rows in ascending order -
% confirm against loadData.
trainMatrix = flipud(trainMatrix);
trainLabel = flipud(trainLabel);

% Decaying learning rate: n0 at epoch 0, shrinking as n0/(1 + epoch*n0).
% Does not depend on the loop variables, so it is defined once.
g = @(n0,i)n0/(1+i*n0);

% testing the hyper-parameter
for C = tryC
    for n0 = tryn0
        for iter = tryIter
            % Every configuration is trained from scratch.
            w = zeros(1,size(trainMatrix, 2));
            for n = 1:iter
                stepSize = g(n0, n);    % constant within an epoch
                for i = 1:numTrain-1
                    for j = i+1:numTrain
                        pairDiff = trainMatrix(i,:) - trainMatrix(j,:);
                        % Regulariser part of the sub-gradient.
                        wGrad = w/divisor;
                        % Hinge part: active only while the pair's margin
                        % (x_i - x_j)*w' is below 1.
                        if pairDiff*w' < 1
                            wGrad = wGrad - C*pairDiff;
                        end
                        w = w - stepSize*wGrad;
                    end
                end
            end

            % Run w on train results
            trainLabel_pred = trainMatrix*w';
            M = corrcoef(trainLabel, trainLabel_pred);
            fprintf('Train correlation = %g\n\n', M(2,1));

            % Run w on val results
            testLabel_pred = testMatrix*w';
            M = corrcoef(testLabel, testLabel_pred);
            fprintf('Val correlation = %g\n\n', M(2,1));

            % Report the configuration that was just evaluated (these are
            % the current values, not the best ones so far).
            fprintf('Hyper-parameters: corr = %g, C = %g, ', M(2,1), C);
            fprintf('n0 = %g, iter = %d, -----------\n\n', n0, iter);

            % Update best hyper-parameters. A NaN correlation (e.g. constant
            % predictions) compares false and therefore never wins.
            if(M(2,1) > bestCorr)
                bestCorr = M(2,1);
                bestIter = iter;
                bestC = C;
                bestn0 = n0;
                bestW = w;
            end
        end
    end
end

fprintf('Best hyper-parameters: bestCorr = %g, bestC = %g, ', bestCorr, bestC);
fprintf('bestn0 = %g, bestIter = %d, -----------\n\n', bestn0, bestIter);