forked from nagadomi/kaggle-lshtc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassifier_storage.hpp
142 lines (130 loc) · 3.36 KB
/
classifier_storage.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#ifndef CLASSIFIER_STORAGE_HPP
#define CLASSIFIER_STORAGE_HPP
#include "binary_classifier.hpp"
#include <cstdio>
#include <map>
#include <memory>
#include <utility>
// Storage for Binary Classifier
class ClassifierStorage
{
private:
std::map<int, BinaryClassifier> m_classifiers;
public:
ClassifierStorage(){}
void
set(unsigned int category_id,
BinaryClassifier &classifier)
{
#ifdef _OPENMP
#pragma omp critical (classifier_storage)
#endif
{
m_classifiers.insert(std::make_pair(category_id, classifier));
}
}
const BinaryClassifier *
get(unsigned int category_id) const
{
const BinaryClassifier *classifier = 0;
#ifdef _OPENMP
#pragma omp critical (classifier_storage)
#endif
{
auto i = m_classifiers.find(category_id);
if (i != m_classifiers.end()) {
classifier = &i->second;
}
}
return classifier;
}
bool
save(const char *file) const
{
FILE *fp = std::fopen(file, "wb");
if (fp == 0) {
return false;
}
size_t size = m_classifiers.size();
std::fwrite(&size, sizeof(size), 1, fp);
for (auto classifier = m_classifiers.begin();
classifier != m_classifiers.end(); ++classifier)
{
int category_id = classifier->first;
std::map<int, float> ws;
float bias = classifier->second.bias();
classifier->second.nonzero_weights(ws);
size = ws.size();
fwrite(&category_id, sizeof(category_id), 1, fp);
fwrite(&size, sizeof(size), 1, fp);
//printf("category_id: %d, %ld\n", category_id, size);
for (auto w = ws.begin(); w != ws.end(); ++w) {
fwrite(&w->first, sizeof(w->first), 1, fp);
fwrite(&w->second, sizeof(w->second), 1, fp);
}
fwrite(&bias, sizeof(bias), 1, fp);
}
fclose(fp);
return true;
}
bool
load(const char *file)
{
FILE *fp = std::fopen(file, "rb");
if (fp == 0) {
return false;
}
m_classifiers.clear();
size_t classifier_num = 0;
size_t ret = std::fread(&classifier_num, sizeof(classifier_num), 1, fp);
if (ret != 1) {
std::fprintf(stderr, "ClassifierStorage: %s: invalid format 1\n", file);
fclose(fp);
return false;
}
for (size_t i = 0; i < classifier_num; ++i) {
int category_id;
size_t vec_size = 0;
float bias = 0.0f;
std::map<int, float> ws;
ret = fread(&category_id, sizeof(category_id), 1, fp);
if (ret != 1) {
std::fprintf(stderr, "ClassifierStorage: %s: invalid format 2\n", file);
fclose(fp);
return false;
}
ret = fread(&vec_size, sizeof(vec_size), 1, fp);
//printf("category_id: %d, %ld\n", category_id, vec_size);
if (ret != 1) {
std::fprintf(stderr, "ClassifierStorage: %s: invalid format 3\n", file);
fclose(fp);
return false;
}
for (size_t i = 0; i < vec_size; ++i) {
int id;
float val;
ret = fread(&id, sizeof(id), 1, fp);
if (ret != 1) {
std::fprintf(stderr, "ClassifierStorage: %s: invalid format 4\n", file);
fclose(fp);
return false;
}
ret = fread(&val, sizeof(val), 1, fp);
if (ret != 1) {
std::fprintf(stderr, "ClassifierStorage: %s: invalid format 4\n", file);
fclose(fp);
return false;
}
ws.insert(std::make_pair(id, val));
}
ret = fread(&bias, sizeof(bias), 1, fp);
if (ret != 1) {
std::fprintf(stderr, "ClassifierStorage: %s: invalid format 5\n", file);
fclose(fp);
return false;
}
BinaryClassifier classifier(ws, bias);
m_classifiers.insert(std::make_pair(category_id, classifier));
}
fclose(fp);
return true;
}
};
#endif