forked from VowpalWabbit/vowpal_wabbit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
global_data.cc
106 lines (96 loc) · 2.35 KB
/
global_data.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <pthread.h>
#include <stdio.h>
#include <float.h>
#include "global_data.h"
#include "multisource.h"
#include "message_relay.h"
using namespace std;
// Process-wide state object (type declared in global_data.h). The print
// functions below read global.active and global.lda from it, and set_mm
// updates global.sd->min_label / global.sd->max_label.
global_data global;
// Version string "6.0"; presumably the program/model version reported to
// users — confirm against its users elsewhere in the project.
string version = "6.0";
// Statically initialized mutex and condition variable. Their names suggest
// they guard and signal completion of prediction output across threads.
// NOTE(review): no use is visible in this file — confirm against callers.
pthread_mutex_t output_lock = PTHREAD_MUTEX_INITIALIZER;
pthread_cond_t output_done = PTHREAD_COND_INITIALIZER;
void binary_print_result(int f, float res, float weight, v_array<char> tag)
{
if (f >= 0)
{
global_prediction ps = {res, weight};
send_global_prediction(f, ps);
}
}
/**
 * Writes one text prediction line to file descriptor `f`.
 *
 * Line layout:
 *   1. the prediction, formatted with "%f";
 *   2. if `tag` is non-empty, a space followed by the raw tag bytes;
 *   3. if global.active is set and `weight` >= 0, a space and the weight
 *      formatted with "%f";
 *   4. a trailing '\n'.
 *
 * Nothing is written when `f` is negative. A short or failed write() is
 * reported on stderr as "write error", and output continues with the next
 * piece of the line.
 *
 * @param f      destination file descriptor; negative disables output.
 * @param res    prediction to print.
 * @param weight printed after the tag when global.active && weight >= 0.
 * @param tag    example tag, written verbatim when non-empty.
 */
void print_result(int f, float res, float weight, v_array<char> tag)
{
  if (f >= 0)
  {
    // "%f" prints every integer digit, so a large float needs far more than
    // the previous 30-byte buffer. -FLT_MAX alone is 47 characters, and the
    // old sprintf overflowed the stack for any |value| >= ~1e22. 64 bytes
    // fits any float plus the leading space used for the weight, and
    // snprintf can never write past the end.
    char temp[64];
    const int cap = static_cast<int>(sizeof(temp));
    ssize_t t;

    int num = snprintf(temp, sizeof(temp), "%f", res);
    // Clamp so write() never reads past temp, even on a formatting error or
    // a truncation (not expected with this buffer size).
    if (num < 0)
      num = 0;
    else if (num >= cap)
      num = cap - 1;
    t = write(f, temp, num);
    if (t != num)
      std::cerr << "write error" << std::endl;

    if (tag.begin != tag.end)
    {
      temp[0] = ' ';
      t = write(f, temp, 1);
      if (t != 1)
        std::cerr << "write error" << std::endl;
      t = write(f, tag.begin, sizeof(char)*tag.index());
      if (t != static_cast<ssize_t>(sizeof(char)*tag.index()))
        std::cerr << "write error" << std::endl;
    }

    if (global.active && weight >= 0)
    {
      num = snprintf(temp, sizeof(temp), " %f", weight);
      if (num < 0)
        num = 0;
      else if (num >= cap)
        num = cap - 1;
      t = write(f, temp, num);
      if (t != num)
        std::cerr << "write error" << std::endl;
    }

    temp[0] = '\n';
    t = write(f, temp, 1);
    if (t != 1)
      std::cerr << "write error" << std::endl;
  }
}
/**
 * Writes one text line of per-topic values to file descriptor `f`.
 *
 * Line layout:
 *   1. global.lda values read from res[0 .. global.lda-1], each formatted as
 *      "%f " (the value followed by a space);
 *   2. if `tag` is non-empty, a space followed by the raw tag bytes;
 *   3. if global.active is set and `weight` >= 0, a space and the weight
 *      formatted with "%f";
 *   4. a trailing '\n'.
 *
 * Nothing is written when `f` is negative. A short or failed write() is
 * reported on stderr as "write error", and output continues with the next
 * piece of the line.
 *
 * @param f      destination file descriptor; negative disables output.
 * @param res    array that must hold at least global.lda floats.
 * @param weight printed after the tag when global.active && weight >= 0.
 * @param tag    example tag, written verbatim when non-empty.
 */
void print_lda_result(int f, float* res, float weight, v_array<char> tag)
{
  if (f >= 0)
  {
    // "%f" prints every integer digit, so large values need far more than
    // the previous 30-byte buffer. -FLT_MAX plus the trailing space is 48
    // characters, and the old sprintf overflowed the stack for any
    // |value| >= ~1e22. 64 bytes fits any float, and snprintf can never
    // write past the end.
    char temp[64];
    const int cap = static_cast<int>(sizeof(temp));
    ssize_t t;
    int num;

    for (size_t k = 0; k < global.lda; k++)
    {
      num = snprintf(temp, sizeof(temp), "%f ", res[k]);
      // Clamp so write() never reads past temp, even on a formatting error
      // or a truncation (not expected with this buffer size).
      if (num < 0)
        num = 0;
      else if (num >= cap)
        num = cap - 1;
      t = write(f, temp, num);
      if (t != num)
        std::cerr << "write error" << std::endl;
    }

    if (tag.begin != tag.end)
    {
      temp[0] = ' ';
      t = write(f, temp, 1);
      if (t != 1)
        std::cerr << "write error" << std::endl;
      t = write(f, tag.begin, sizeof(char)*tag.index());
      if (t != static_cast<ssize_t>(sizeof(char)*tag.index()))
        std::cerr << "write error" << std::endl;
    }

    if (global.active && weight >= 0)
    {
      num = snprintf(temp, sizeof(temp), " %f", weight);
      if (num < 0)
        num = 0;
      else if (num >= cap)
        num = cap - 1;
      t = write(f, temp, num);
      if (t != num)
        std::cerr << "write error" << std::endl;
    }

    temp[0] = '\n';
    t = write(f, temp, 1);
    if (t != 1)
      std::cerr << "write error" << std::endl;
  }
}
/**
 * Widens the label range recorded in global.sd so that it includes `label`.
 *
 * The lower bound (min_label) is lowered whenever `label` is smaller. The
 * upper bound (max_label) is raised whenever `label` is larger, except when
 * `label` equals FLT_MAX. FLT_MAX is presumably the "no label" sentinel —
 * confirm against the parser.
 * Comparisons involving NaN are false, so a NaN label changes neither bound.
 */
void set_mm(double label)
{
  if (label < global.sd->min_label)
    global.sd->min_label = label;

  const bool is_sentinel = (label == FLT_MAX);
  if (!is_sentinel && global.sd->max_label < label)
    global.sd->max_label = label;
}
/**
 * Label-range hook that ignores its argument.
 * It has the same signature as set_mm, presumably so that it can be
 * assigned to set_minmax to disable min/max label tracking.
 */
void noop_mm(double label)
{
  (void) label; // deliberately unused
}
// Hook called with each label to update the tracked label range. It starts
// out pointing at set_mm; noop_mm has a matching signature and can
// presumably be swapped in to disable tracking.
// NOTE(review): where it is reassigned is not visible in this file — confirm.
void (*set_minmax)(double label) = set_mm;