-
Notifications
You must be signed in to change notification settings - Fork 0
/
learner_multilabel.cpp
134 lines (109 loc) · 3.49 KB
/
learner_multilabel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Implementation of multi-label learner
//
// Copyright (C) 2012 Heidelberg University
//
// Author: Sascha Fendrich
//
// This file is part of Sol.
//
// Sol is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Sol is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with Sol. If not, see <http://www.gnu.org/licenses/>.
#include <iostream>
#include <boost/program_options.hpp>
#include "tiny_log.h"
#include "learner_multilabel.h"
#include "weight_vector.h"
#include "sparse_vector.h"
namespace po = boost::program_options;
// Construct the learner and register its command line options.
//
// Adds a "Multi-label options" group to options_. The "num-labels" option
// (short form "c") is bound to the num_labels_ member.
MultiLabelLearner::MultiLabelLearner ()
{
  po::options_description multilabel_options ("Multi-label options");

  // TODO: make num-labels a required option
  multilabel_options.add_options ()
    ("num-labels,c",
     po::value<int> (&num_labels_),
     "number of labels (class labels in 0 ... 2^arg - 1)");

  options_.add (multilabel_options);
}
// Initialize multi-label learner
//
// Forwards argc/argv to Learner::Init and then derives the multi-label
// specific settings from num_labels_. Returns whatever Learner::Init returned.
int MultiLabelLearner::Init (int argc, char **argv)
{
  // The parent runs first; presumably it parses the command line and thereby
  // fills num_labels_ -- confirm against Learner::Init.
  const int parent_status = Learner::Init (argc, argv);

  // One submodel per label
  num_submodels_ = num_labels_;

  // NOTE(review): the option help text advertises class labels in
  // 0 ... 2^arg - 1 (i.e. 2^arg classes), whereas this computes 2^(arg - 1);
  // confirm which one is intended. The shift also needs num_labels_ >= 1.
  num_classes_ = 1 << (num_labels_ - 1);

  return parent_status;
}
// Learn multi-label classifier
bool MultiLabelLearner::SingleUpdate (const DataSet &data_set)
{
int index = rand () % data_set.size ();
int target = int (data_set[index].target ());
bool model_updated = false;
// Update from loss
for (int j = 0; j < model_.num_submodels (); ++j)
{
int current_class = 1 << j;
float target_sign = (target & current_class)?1:-1;
float bias = model_[j].bias ();
float score = model_[j].InnerProduct (data_set[index]) + bias;
if (target_sign * score < 1)
{
model_[j].PlusEquals (target_sign * learning_rate_, data_set[index]);
model_[j].set_bias (bias + learning_rate_ * target_sign);
model_updated = true;
}
}
return model_updated;
}
// Evaluate multi-label classifier
//
// For every example in data_set, builds a predicted bitmask (bit j is set
// when submodel j's inner product plus bias is positive) and compares it for
// exact equality with the example's integer target. The fraction of exact
// matches is logged via INFO. Depending on print_predictions_ and
// print_result_, the per-example predictions and the final fraction are also
// written to stdout.
void MultiLabelLearner::Evaluate (const DataSet &data_set)
{
  int num_correct = 0;
  int num_wrong = 0;
  const int num_examples = data_set.size ();

  for (int example = 0; example < num_examples; ++example)
    {
      // Assemble the predicted label set, one bit per submodel
      int prediction = 0;
      for (int label = 0; label < model_.num_submodels (); ++label)
        {
          const float label_score =
            model_[label].InnerProduct (data_set[example]) +
            model_[label].bias ();
          if (label_score > 0)
            prediction |= 1 << label;
        }

      // Count exact matches against the target bitmask
      const int expected = int (data_set[example].target ());
      if (prediction == expected)
        ++num_correct;
      else
        ++num_wrong;

      // Emit the prediction if requested
      if (print_predictions_)
        std::cout << prediction << std::endl;

      // Periodic progress line
      if ((progress_interval_ > 0) && (example % progress_interval_ == 0))
        INFO << example << '/' << num_examples << '\r';
    }

  const int num_evaluated = num_correct + num_wrong;
  const float result = float (num_correct) / float (num_evaluated);

  // Log result
  INFO << "result: " << result
       << " (" << num_correct << '/' << num_evaluated << ')' << std::endl;

  // Write result to stdout
  if (print_result_)
    std::cout << result << std::endl;
}