forked from bountylabs/go-fasttext
-
Notifications
You must be signed in to change notification settings - Fork 1
/
prediction.cpp
45 lines (36 loc) · 1.14 KB
/
prediction.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <istream>
#include <memory>
#include <queue>
#include <stdexcept>
#include <streambuf>
#include <string_view>
#include <fasttext/include/fasttext.h>
#include "predictions.h"
BEGIN_EXTERN_C()
/// Runs fastText prediction on `query` and writes up to `k` results into `value`.
///
/// @param handle    Opaque handle; reinterpreted as a `fasttext::FastText *`.
/// @param query     Input text, viewed in place via (`query.data`, `query.size`).
/// @param k         Maximum number of predictions; `value` must have room for at
///                  least `k` entries.
/// @param threshold Forwarded unchanged to `predictFull`.
/// @param value     Output array; entry i receives the i-th prediction's score
///                  (`probability`) and label (`lang`).
/// @return Number of entries written to `value` (at most `k`). Returns 0 when
///         `handle` or `value` is null, or when any exception is raised while
///         predicting (in which case `value` may be partially written).
///
/// Each `lang` is the prediction's word with its first LABEL_PREFIX_SIZE
/// characters removed, truncated to at most 8 characters. It is not
/// null-terminated and is not copied: `lang.data` points into `prediction.word`.
/// NOTE(review): that pointer is only valid after this function returns if
/// `predictFull` yields words referencing model-owned storage (e.g. string_views
/// into the dictionary) rather than strings owned by the local `predictions`
/// container -- confirm against the fasttext fork's `predictFull`.
size_t FastText_Predict(const FastText_Handle_t handle, FastText_String_t query, uint32_t k, float threshold,
                        FastText_PredictItem_t *const value)
{
    // Labels longer than this are truncated in the output.
    constexpr size_t kMaxLangSize = 8;

    const auto model = reinterpret_cast<fasttext::FastText *>(handle);
    if (model == nullptr || value == nullptr)
    {
        return 0;
    }

    // This is a C ABI entry point: an exception unwinding out of it (e.g. into
    // cgo frames) terminates the process, so every exception is contained here.
    try
    {
        auto predictions = model->predictFull(k, std::string_view(query.data, query.size), threshold);
        const size_t count = std::min<size_t>(k, predictions.size());
        for (size_t i = 0; i < count; i++)
        {
            const auto &prediction = predictions.at(i);

            // View the word first so the prefix strip yields a view into
            // prediction.word, not a temporary std::string that would dangle
            // once the statement ends.
            std::string_view data(prediction.word);
            // Strip the label prefix; a word shorter than the prefix becomes empty
            // instead of throwing std::out_of_range as substr() would.
            data.remove_prefix(std::min<size_t>(data.size(), LABEL_PREFIX_SIZE));

            const size_t size = std::min(data.size(), kMaxLangSize);
            value[i].probability = prediction.score;
            value[i].lang = FastText_String_t{
                .size = size,
                .data = const_cast<char *>(data.data()),
            };
        }
        return count;
    }
    catch (...)
    {
        return 0;
    }
}
END_EXTERN_C();