From 583bf8d7398d35cf0407223a238b77b865539d49 Mon Sep 17 00:00:00 2001 From: Dusan Malusev Date: Thu, 18 Apr 2024 23:21:53 +0200 Subject: [PATCH] Implement better way (allocation free mostly) language detection Signed-off-by: Dusan Malusev --- .vscode/settings.json | 76 ---------- cbits.cpp | 173 ---------------------- fasttext.cpp | 2 +- fasttext.go | 150 +++---------------- fasttext/{ => include}/aligned.h | 0 fasttext/{ => include}/args.h | 0 fasttext/{ => include}/autotune.h | 0 fasttext/{ => include}/densematrix.h | 0 fasttext/{ => include}/dictionary.h | 2 +- fasttext/{ => include}/fasttext.h | 3 +- fasttext/{ => include}/fullprediction.h | 8 +- fasttext/{ => include}/loss.h | 0 fasttext/{ => include}/matrix.h | 0 fasttext/{ => include}/meter.h | 0 fasttext/{ => include}/model.h | 0 fasttext/{ => include}/productquantizer.h | 0 fasttext/{ => include}/quantmatrix.h | 0 fasttext/{ => include}/real.h | 0 fasttext/{ => include}/utils.h | 0 fasttext/{ => include}/vector.h | 0 fasttext/{ => src}/args.cc | 0 fasttext/{ => src}/autotune.cc | 0 fasttext/{ => src}/densematrix.cc | 0 fasttext/{ => src}/dictionary.cc | 4 +- fasttext/{ => src}/fasttext.cc | 25 +++- fasttext/{ => src}/loss.cc | 0 fasttext/{ => src}/main.cc | 0 fasttext/{ => src}/matrix.cc | 0 fasttext/{ => src}/meter.cc | 0 fasttext/{ => src}/model.cc | 0 fasttext/{ => src}/productquantizer.cc | 0 fasttext/{ => src}/quantmatrix.cc | 0 fasttext/{ => src}/utils.cc | 0 fasttext/{ => src}/vector.cc | 0 fasttext_test.go | 15 -- fasttextlib.cpp | 12 ++ handle.cpp | 39 +++++ helpers.go | 9 ++ prediction.cpp | 45 ++++++ predictions.cpp | 60 ++++++++ cbits.h => predictions.h | 22 +-- 41 files changed, 229 insertions(+), 416 deletions(-) delete mode 100644 cbits.cpp rename fasttext/{ => include}/aligned.h (100%) rename fasttext/{ => include}/args.h (100%) rename fasttext/{ => include}/autotune.h (100%) rename fasttext/{ => include}/densematrix.h (100%) rename fasttext/{ => include}/dictionary.h (97%) rename fasttext/{ => include}/fasttext.h (97%) rename fasttext/{ => include}/fullprediction.h (75%) rename fasttext/{ => include}/loss.h (100%) rename fasttext/{ => include}/matrix.h (100%) rename fasttext/{ => include}/meter.h (100%) rename fasttext/{ => include}/model.h (100%) rename fasttext/{ => include}/productquantizer.h (100%) rename fasttext/{ => include}/quantmatrix.h (100%) rename fasttext/{ => include}/real.h (100%) rename fasttext/{ => include}/utils.h (100%) rename fasttext/{ => include}/vector.h (100%) rename fasttext/{ => src}/args.cc (100%) rename fasttext/{ => src}/autotune.cc (100%) rename fasttext/{ => src}/densematrix.cc (100%) rename fasttext/{ => src}/dictionary.cc (99%) rename fasttext/{ => src}/fasttext.cc (98%) rename fasttext/{ => src}/loss.cc (100%) rename fasttext/{ => src}/main.cc (100%) rename fasttext/{ => src}/matrix.cc (100%) rename fasttext/{ => src}/meter.cc (100%) rename fasttext/{ => src}/model.cc (100%) rename fasttext/{ => src}/productquantizer.cc (100%) rename fasttext/{ => src}/quantmatrix.cc (100%) rename fasttext/{ => src}/utils.cc (100%) rename fasttext/{ => src}/vector.cc (100%) create mode 100644 fasttextlib.cpp create mode 100644 handle.cpp create mode 100644 helpers.go create mode 100644 prediction.cpp create mode 100644 predictions.cpp rename cbits.h => predictions.h (59%) diff --git a/.vscode/settings.json b/.vscode/settings.json index 8c78476..de6bc22 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -148,12 +148,6 @@ "editor.formatOnSave": true, "editor.formatOnType": true, "editor.cursorBlinking": "smooth", - "[jsonc]": { - "editor.defaultFormatter": "vscode.json-language-features" - }, - "[json]": { - "editor.defaultFormatter": "vscode.json-language-features" - }, "git.alwaysSignOff": true, "git.autofetch": true, "git.ignoreLimitWarning": true, @@ -161,9 +155,6 @@ "files.trimTrailingWhitespace": true, "explorer.incrementalNaming": "smart", "explorer.sortOrder": "type", - "[proto3]": { - "editor.defaultFormatter": "zxh404.vscode-proto3" - }, "files.exclude": { "**/.idea/": true, "**/.nuke": true, @@ -177,71 +168,4 @@ "debug.allowBreakpointsEverywhere": true, "debug.console.historySuggestions": true, "debug.console.collapseIdenticalLines": true, - "ansible.python.interpreterPath": "/bin/python", - "files.associations": { - "*main.yml": "ansible", - "string_view": "cpp", - "istream": "cpp", - "array": "cpp", - "atomic": "cpp", - "bit": "cpp", - "*.tcc": "cpp", - "cctype": "cpp", - "charconv": "cpp", - "chrono": "cpp", - "clocale": "cpp", - "cmath": "cpp", - "compare": "cpp", - "concepts": "cpp", - "csignal": "cpp", - "cstdarg": "cpp", - "cstddef": "cpp", - "cstdint": "cpp", - "cstdio": "cpp", - "cstdlib": "cpp", - "cstring": "cpp", - "ctime": "cpp", - "cwchar": "cpp", - "cwctype": "cpp", - "deque": "cpp", - "set": "cpp", - "string": "cpp", - "unordered_map": "cpp", - "unordered_set": "cpp", - "vector": "cpp", - "exception": "cpp", - "algorithm": "cpp", - "functional": "cpp", - "iterator": "cpp", - "memory": "cpp", - "memory_resource": "cpp", - "numeric": "cpp", - "optional": "cpp", - "random": "cpp", - "ratio": "cpp", - "system_error": "cpp", - "tuple": "cpp", - "type_traits": "cpp", - "utility": "cpp", - "format": "cpp", - "fstream": "cpp", - "initializer_list": "cpp", - "iomanip": "cpp", - "iosfwd": "cpp", - "iostream": "cpp", - "limits": "cpp", - "new": "cpp", - "numbers": "cpp", - "ostream": "cpp", - "semaphore": "cpp", - "span": "cpp", - "sstream": "cpp", - "stdexcept": "cpp", - "stop_token": "cpp", - "streambuf": "cpp", - "text_encoding": "cpp", - "thread": "cpp", - "typeinfo": "cpp", - "variant": "cpp" - } } diff --git a/cbits.cpp b/cbits.cpp deleted file mode 100644 index b8149ca..0000000 --- a/cbits.cpp +++ /dev/null @@ -1,173 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "cbits.h" - -#define FREE_STRING(str) \ - do \ - { \ - if (str.data != nullptr) \ - free(str.data); \ - str.data = nullptr; \ - str.size = 0; \ - } while (0) - -#define LABEL_PREFIX ("__label__") -#define LABEL_PREFIX_SIZE (sizeof(LABEL_PREFIX) - 1) - -using Predictions = std::vector>; - -struct membuf : std::streambuf -{ - membuf(FastText_String_t query) - { - this->setg(query.data, query.data, query.data + query.size); - } -}; - -BEGIN_EXTERN_C() -FastText_Result_t FastText_NewHandle(FastText_String_t path) -{ - auto model = new fasttext::FastText(); - - try - { - model->loadModel(std::string(path.data, path.size)); - return FastText_Result_t{ - FastText_Result_t::SUCCESS, - (FastText_Handle_t)model, - }; - } - catch (std::exception &e) - { - return FastText_Result_t{ - FastText_Result_t::ERROR, - strdup(e.what()), - }; - } -} - -void FastText_DeleteHandle(const FastText_Handle_t handle) -{ - if (handle != nullptr) - { - return; - } - - const auto model = reinterpret_cast(handle); - delete model; -} - -size_t FastText_Predict(const FastText_Handle_t handle, FastText_String_t query, uint32_t k, float threshold, - FastText_PredictItem_t *const value) -{ - const auto model = reinterpret_cast(handle); - auto predictions = model->predictFull(k, std::string_view(query.data, query.size), threshold); - const auto count = k > predictions.size() ? predictions.size() : k; - - for (size_t i = 0; i < count; i++) - { - const auto &prediction = predictions.at(i); - - std::string_view data = prediction.word.substr(LABEL_PREFIX_SIZE); - size_t size = data.size(); - - if (size > 8) - { - size = 8; - } - - value[i].probability = prediction.score; - value[i].lang = FastText_String_t{ - .size = size, - .data = (char *)data.data(), - }; - } - - return count; -} - -// FastText_Predict_t FastText_Analogy(const FastText_Handle_t handle, FastText_String_t word1, FastText_String_t word2, -// FastText_String_t word3, uint32_t k) -// { -// const auto model = reinterpret_cast(handle); -// Predictions predictions = -// model->getAnalogies(k, std::string(word1.data, word1.size), std::string(word2.data, word2.size), -// std::string(word3.data, word3.size)); - -// auto vec = new Predictions(std::move(predictions)); - -// return FastText_Predict_t{ -// vec->size(), -// (void *)vec, -// }; -// } - -FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText_String_t word) -{ - const auto model = reinterpret_cast(handle); - int64_t dimensions = model->getDimension(); - - auto vec = new fasttext::Vector(dimensions); - model->getWordVector(*vec, std::string(word.data, word.size)); - - return FastText_FloatVector_t{ - vec->data(), - (void *)vec, - (size_t)vec->size(), - }; -} - -FastText_FloatVector_t FastText_Sentencevec(const FastText_Handle_t handle, FastText_String_t sentence) -{ - const auto model = reinterpret_cast(handle); - - membuf sbuf(sentence); - std::istream in(&sbuf); - - auto vec = new fasttext::Vector(model->getDimension()); - model->getSentenceVector(in, *vec); - FREE_STRING(sentence); - - return FastText_FloatVector_t{ - vec->data(), - (void *)vec, - (size_t)vec->size(), - }; -} - -void FastText_FreeFloatVector(FastText_FloatVector_t vector) -{ - auto vec = reinterpret_cast(vector.handle); - delete vec; -} - -void FastText_FreePredict(FastText_Predict_t predict) -{ - auto vec = reinterpret_cast(predict.data); - delete vec; -} - -END_EXTERN_C() diff --git a/fasttext.cpp b/fasttext.cpp index f99f22e..cde5788 100644 --- a/fasttext.cpp +++ b/fasttext.cpp @@ -1 +1 @@ -#include +#include diff --git a/fasttext.go b/fasttext.go index c7704a7..28d840b 100644 --- a/fasttext.go +++ b/fasttext.go @@ -1,11 +1,10 @@ package fasttext -// #cgo CXXFLAGS: -I${SRCDIR}/fasttext -I${SRCDIR} -std=c++17 -O3 -fPIC -pthread -march=native -// #cgo LDFLAGS: -lstdc++ +// #cgo CXXFLAGS: -I${SRCDIR}/fasttext/include -I${SRCDIR} -std=c++17 -Ofast -fPIC -pthread -Wno-defaulted-function-deleted // #include // #include // #include -// #include "cbits.h" +// #include "predictions.h" import "C" import ( @@ -84,35 +83,6 @@ func (handle *Model) MultiLinePredict(lines []string, k int32, threshoad float32 return predics, nil } -// func (handle *Model) PredictOne(query string, threshoad float32) (Prediction, error) { -// r := C.FastText_Predict( -// handle.p, -// C.FastText_String_t{ -// data: cStr(query), -// size: C.size_t(len(query)), -// }, -// 1, -// C.float(threshoad), -// ) - -// if r.data == nil { -// return Prediction{}, ErrPredictionFailed -// } - -// defer C.FastText_FreePredict(r) - -// if r.size == 0 { -// return Prediction{}, ErrNoPredictions -// } - -// cPredic := C.FastText_PredictItemAt(r, C.size_t(0)) - -// return Prediction{ -// Label: C.GoStringN(cPredic.label.data, C.int(cPredic.label.size)), -// Probability: float32(cPredic.probability), -// }, nil -// } - // Perform model prediction func (handle *Model) Predict(query string, k int32, threshoad float32) (Predictions, error) { var pinner runtime.Pinner @@ -151,107 +121,25 @@ func (handle *Model) Predict(query string, k int32, threshoad float32) (Predicti return predictions, nil } -// func (handle *Model) Analogy(word1, word2, word3 string, k int32) Analogs { -// // cWord1 := ((*C.char) unsafe.Pointer(unsafe.StringData(word1))) - -// var pinner runtime.Pinner -// defer pinner.Unpin() - -// pinner.Pin(word1) -// pinner.Pin(word2) -// pinner.Pin(word3) - -// strWord1 := cStr(word1) -// pinner.Pin(strWord1) -// strWord2 := cStr(word2) -// pinner.Pin(strWord2) -// strWord3 := cStr(word3) -// pinner.Pin(strWord3) - -// r := C.FastText_Analogy( -// handle.p, -// C.FastText_String_t{ -// data: strWord1, -// size: C.size_t(len(word1)), -// }, -// C.FastText_String_t{ -// data: strWord2, -// size: C.size_t(len(word2)), -// }, -// C.FastText_String_t{ -// data: strWord3, -// size: C.size_t(len(word3)), -// }, -// C.uint32_t(k), -// ) - -// defer C.FastText_FreePredict(r) - -// analogs := make(Analogs, r.size) - -// for i := uint64(0); i < uint64(r.size); i++ { -// cPredic := C.FastText_PredictItemAt(r, C.size_t(i)) - -// analogs[i] = Analog{ -// Name: C.GoStringN(cPredic.label.data, C.int(cPredic.label.size)), -// Probability: float32(cPredic.probability), -// } -// } - -// return analogs -// } - -// func (handle Model) Wordvec(word string) []float32 { -// var pinner runtime.Pinner -// defer pinner.Unpin() - -// pinner.Pin(word) -// strData := cStr(word) -// pinner.Pin(strData) - -// r := C.FastText_Wordvec( -// handle.p, -// C.FastText_String_t{ -// data: strData, -// size: C.size_t(len(word)), -// }, -// ) -// defer C.FastText_FreeFloatVector(r) - -// vectors := make([]float32, r.size) -// pinner.Pin(r.data) - -// ptr := (*float32)(unsafe.Pointer(r.data)) -// pinner.Pin(ptr) - -// copy(vectors, unsafe.Slice(ptr, r.size)) - -// return vectors -// } - -// Sentencevec requires sentence ends with -// func (handle Model) Sentencevec(query string) []float32 { -// var pinner runtime.Pinner -// defer pinner.Unpin() -// pinner.Pin(query) -// strData := cStr(query) -// pinner.Pin(strData) -// r := C.FastText_Sentencevec(handle.p, C.FastText_String_t{ -// data: strData, -// size: C.size_t(len(query)), -// }) +func (handle Model) Wordvec(word string) []float32 { + var pinner runtime.Pinner + defer pinner.Unpin() -// defer C.FastText_FreeFloatVector(r) + strData := cStr(word) + pinner.Pin(strData) -// vectors := make([]float32, r.size) -// pinner.Pin(r.data) -// ptr := (*float32)(unsafe.Pointer(r.data)) -// pinner.Pin(ptr) -// copy(vectors, unsafe.Slice(ptr, r.size)) + r := C.FastText_Wordvec( + handle.p, + C.FastText_String_t{ + data: strData, + size: C.size_t(len(word)), + }, + ) -// return vectors -// } + defer C.FastText_FreeFloatVector(r) -func cStr(str string) *C.char { - return (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + vectors := make([]float32, r.size) + ptr := (*float32)(unsafe.Pointer(r.data)) + copy(vectors, unsafe.Slice(ptr, r.size)) + return vectors } diff --git a/fasttext/aligned.h b/fasttext/include/aligned.h similarity index 100% rename from fasttext/aligned.h rename to fasttext/include/aligned.h diff --git a/fasttext/args.h b/fasttext/include/args.h similarity index 100% rename from fasttext/args.h rename to fasttext/include/args.h diff --git a/fasttext/autotune.h b/fasttext/include/autotune.h similarity index 100% rename from fasttext/autotune.h rename to fasttext/include/autotune.h diff --git a/fasttext/densematrix.h b/fasttext/include/densematrix.h similarity index 100% rename from fasttext/densematrix.h rename to fasttext/include/densematrix.h diff --git a/fasttext/dictionary.h b/fasttext/include/dictionary.h similarity index 97% rename from fasttext/dictionary.h rename to fasttext/include/dictionary.h index 861dbfc..e2a6be0 100644 --- a/fasttext/dictionary.h +++ b/fasttext/include/dictionary.h @@ -83,7 +83,7 @@ class Dictionary bool discard(int32_t, real) const; std::string getWord(int32_t) const; const std::vector &getSubwords(int32_t) const; - const std::vector getSubwords(const std::string &) const; + const std::vector getSubwords(const std::string_view) const; void getSubwords(const std::string &, std::vector &, std::vector &) const; void computeSubwords(const std::string &, std::vector &, std::vector *substrings = nullptr) const; diff --git a/fasttext/fasttext.h b/fasttext/include/fasttext.h similarity index 97% rename from fasttext/fasttext.h rename to fasttext/include/fasttext.h index 7927a8d..acc6efc 100644 --- a/fasttext/fasttext.h +++ b/fasttext/include/fasttext.h @@ -84,7 +84,8 @@ class FastText int32_t getLabelId(const std::string &label) const; - void getWordVector(Vector &vec, const std::string &word) const; + Vector getWordVector(const std::string_view word) const; + void getWordVector(Vector &vec, const std::string_view word) const; void getSubwordVector(Vector &vec, const std::string &subword) const; diff --git a/fasttext/fullprediction.h b/fasttext/include/fullprediction.h similarity index 75% rename from fasttext/fullprediction.h rename to fasttext/include/fullprediction.h index 16777e0..a5a3438 100644 --- a/fasttext/fullprediction.h +++ b/fasttext/include/fullprediction.h @@ -14,16 +14,16 @@ class FullPrediction }; private: - Predictions predictions_; + fasttext::Predictions predictions_; std::shared_ptr dict_; public: - FullPrediction(Predictions &&predictions, const std::shared_ptr &dict) + FullPrediction(fasttext::Predictions &&predictions, const std::shared_ptr &dict) : predictions_(std::move(predictions)), dict_(dict) { } - Item at(size_t idx) const + inline Item at(size_t idx) const { const auto &prediction = predictions_.at(idx); @@ -35,7 +35,7 @@ class FullPrediction return item; } - size_t size() const + inline size_t size() const { return predictions_.size(); } diff --git a/fasttext/loss.h b/fasttext/include/loss.h similarity index 100% rename from fasttext/loss.h rename to fasttext/include/loss.h diff --git a/fasttext/matrix.h b/fasttext/include/matrix.h similarity index 100% rename from fasttext/matrix.h rename to fasttext/include/matrix.h diff --git a/fasttext/meter.h b/fasttext/include/meter.h similarity index 100% rename from fasttext/meter.h rename to fasttext/include/meter.h diff --git a/fasttext/model.h b/fasttext/include/model.h similarity index 100% rename from fasttext/model.h rename to fasttext/include/model.h diff --git a/fasttext/productquantizer.h b/fasttext/include/productquantizer.h similarity index 100% rename from fasttext/productquantizer.h rename to fasttext/include/productquantizer.h diff --git a/fasttext/quantmatrix.h b/fasttext/include/quantmatrix.h similarity index 100% rename from fasttext/quantmatrix.h rename to fasttext/include/quantmatrix.h diff --git a/fasttext/real.h b/fasttext/include/real.h similarity index 100% rename from fasttext/real.h rename to fasttext/include/real.h diff --git a/fasttext/utils.h b/fasttext/include/utils.h similarity index 100% rename from fasttext/utils.h rename to fasttext/include/utils.h diff --git a/fasttext/vector.h b/fasttext/include/vector.h similarity index 100% rename from fasttext/vector.h rename to fasttext/include/vector.h diff --git a/fasttext/args.cc b/fasttext/src/args.cc similarity index 100% rename from fasttext/args.cc rename to fasttext/src/args.cc diff --git a/fasttext/autotune.cc b/fasttext/src/autotune.cc similarity index 100% rename from fasttext/autotune.cc rename to fasttext/src/autotune.cc diff --git a/fasttext/densematrix.cc b/fasttext/src/densematrix.cc similarity index 100% rename from fasttext/densematrix.cc rename to fasttext/src/densematrix.cc diff --git a/fasttext/dictionary.cc b/fasttext/src/dictionary.cc similarity index 99% rename from fasttext/dictionary.cc rename to fasttext/src/dictionary.cc index ab96b4e..bfe22cb 100644 --- a/fasttext/dictionary.cc +++ b/fasttext/src/dictionary.cc @@ -92,7 +92,7 @@ const std::vector &Dictionary::getSubwords(int32_t i) const return words_[i].subwords; } -const std::vector Dictionary::getSubwords(const std::string &word) const +const std::vector Dictionary::getSubwords(const std::string_view word) const { int32_t i = getId(word); if (i >= 0) @@ -102,7 +102,7 @@ const std::vector Dictionary::getSubwords(const std::string &word) cons std::vector ngrams; if (word != EOS) { - computeSubwords(BOW + word + EOW, ngrams); + computeSubwords(BOW + word.data() + EOW, ngrams); } return ngrams; } diff --git a/fasttext/fasttext.cc b/fasttext/src/fasttext.cc similarity index 98% rename from fasttext/fasttext.cc rename to fasttext/src/fasttext.cc index d5968d3..1c94662 100644 --- a/fasttext/fasttext.cc +++ b/fasttext/src/fasttext.cc @@ -120,14 +120,35 @@ int32_t FastText::getLabelId(const std::string &label) const return labelId; } -void FastText::getWordVector(Vector &vec, const std::string &word) const +Vector FastText::getWordVector(const std::string_view word) const +{ + const std::vector &ngrams = dict_->getSubwords(word); + Vector vec(args_->dim); + vec.zero(); + + for (int i = 0; i < ngrams.size(); i++) + { + addInputVector(vec, ngrams[i]); + } + + if (ngrams.size() > 0) + { + vec.mul(1.0 / ngrams.size()); + } + + return std::move(vec); +} + +void FastText::getWordVector(Vector &vec, std::string_view word) const { const std::vector &ngrams = dict_->getSubwords(word); vec.zero(); + for (int i = 0; i < ngrams.size(); i++) { addInputVector(vec, ngrams[i]); } + if (ngrams.size() > 0) { vec.mul(1.0 / ngrams.size()); @@ -158,7 +179,7 @@ void FastText::saveVectors(const std::string &filename) for (int32_t i = 0; i < dict_->nwords(); i++) { std::string word = dict_->getWord(i); - getWordVector(vec, word); + getWordVector(vec, std::string_view(word)); ofs << word << " " << vec << std::endl; } ofs.close(); diff --git a/fasttext/loss.cc b/fasttext/src/loss.cc similarity index 100% rename from fasttext/loss.cc rename to fasttext/src/loss.cc diff --git a/fasttext/main.cc b/fasttext/src/main.cc similarity index 100% rename from fasttext/main.cc rename to fasttext/src/main.cc diff --git a/fasttext/matrix.cc b/fasttext/src/matrix.cc similarity index 100% rename from fasttext/matrix.cc rename to fasttext/src/matrix.cc diff --git a/fasttext/meter.cc b/fasttext/src/meter.cc similarity index 100% rename from fasttext/meter.cc rename to fasttext/src/meter.cc diff --git a/fasttext/model.cc b/fasttext/src/model.cc similarity index 100% rename from fasttext/model.cc rename to fasttext/src/model.cc diff --git a/fasttext/productquantizer.cc b/fasttext/src/productquantizer.cc similarity index 100% rename from fasttext/productquantizer.cc rename to fasttext/src/productquantizer.cc diff --git a/fasttext/quantmatrix.cc b/fasttext/src/quantmatrix.cc similarity index 100% rename from fasttext/quantmatrix.cc rename to fasttext/src/quantmatrix.cc diff --git a/fasttext/utils.cc b/fasttext/src/utils.cc similarity index 100% rename from fasttext/utils.cc rename to fasttext/src/utils.cc diff --git a/fasttext/vector.cc b/fasttext/src/vector.cc similarity index 100% rename from fasttext/vector.cc rename to fasttext/src/vector.cc diff --git a/fasttext_test.go b/fasttext_test.go index 6c316fe..fb32e4b 100644 --- a/fasttext_test.go +++ b/fasttext_test.go @@ -29,21 +29,6 @@ func TestOpen(t *testing.T) { }) } -// func TestPredictOne(t *testing.T) { -// t.Parallel() -// assert := require.New(t) - -// model, err := fasttext.Open("testdata/lid.176.ftz") -// assert.NoError(err) - -// prediction, err := model.PredictOne("hello world from my dear C++", 0.0) - -// assert.NoError(err) -// assert.NotEmpty(prediction) -// assert.Equal("en", prediction.Label) -// assert.Greater(prediction.Probability, float32(0.7)) -// } - func TestMultilinePredict(t *testing.T) { t.Parallel() assert := require.New(t) diff --git a/fasttextlib.cpp b/fasttextlib.cpp new file mode 100644 index 0000000..12e9b8e --- /dev/null +++ b/fasttextlib.cpp @@ -0,0 +1,12 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/handle.cpp b/handle.cpp new file mode 100644 index 0000000..b241da4 --- /dev/null +++ b/handle.cpp @@ -0,0 +1,39 @@ +#include +#include + +#include "predictions.h" + +BEGIN_EXTERN_C() + +FastText_Result_t FastText_NewHandle(FastText_String_t path) +{ + auto model = new fasttext::FastText(); + + try + { + model->loadModel(std::string(path.data, path.size)); + return FastText_Result_t{ + FastText_Result_t::SUCCESS, + (FastText_Handle_t)model, + }; + } + catch (std::exception &e) + { + return FastText_Result_t{ + FastText_Result_t::ERROR, + strdup(e.what()), + }; + } +} + +void FastText_DeleteHandle(const FastText_Handle_t handle) +{ + if (handle != nullptr) + { + return; + } + + const auto model = reinterpret_cast(handle); + delete model; +} +END_EXTERN_C() diff --git a/helpers.go b/helpers.go new file mode 100644 index 0000000..a643128 --- /dev/null +++ b/helpers.go @@ -0,0 +1,9 @@ +package fasttext + +import "C" + +import "unsafe" + +func cStr(str string) *C.char { + return (*C.char)(unsafe.Pointer(unsafe.StringData(str))) +} diff --git a/prediction.cpp b/prediction.cpp new file mode 100644 index 0000000..56820aa --- /dev/null +++ b/prediction.cpp @@ -0,0 +1,45 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "predictions.h" + +BEGIN_EXTERN_C() +size_t FastText_Predict(const FastText_Handle_t handle, FastText_String_t query, uint32_t k, float threshold, + FastText_PredictItem_t *const value) +{ + const auto model = reinterpret_cast(handle); + auto predictions = model->predictFull(k, std::string_view(query.data, query.size), threshold); + const auto count = k > predictions.size() ? predictions.size() : k; + + for (size_t i = 0; i < count; i++) + { + const auto &prediction = predictions.at(i); + + std::string_view data = prediction.word.substr(LABEL_PREFIX_SIZE); + size_t size = data.size(); + + if (size > 8) + { + size = 8; + } + + value[i].probability = prediction.score; + value[i].lang = FastText_String_t{ + .size = size, + .data = (char *)data.data(), + }; + } + + return count; +} + +END_EXTERN_C(); diff --git a/predictions.cpp b/predictions.cpp new file mode 100644 index 0000000..0930060 --- /dev/null +++ b/predictions.cpp @@ -0,0 +1,60 @@ +#include + +#include "predictions.h" + +BEGIN_EXTERN_C() + +// FastText_Predict_t FastText_Analogy(const FastText_Handle_t handle, FastText_String_t word1, FastText_String_t word2, +// FastText_String_t word3, uint32_t k) +// { +// const auto model = reinterpret_cast(handle); +// Predictions predictions = +// model->getAnalogies(k, std::string(word1.data, word1.size), std::string(word2.data, word2.size), +// std::string(word3.data, word3.size)); + +// auto vec = new Predictions(std::move(predictions)); + +// return FastText_Predict_t{ +// vec->size(), +// (void *)vec, +// }; +// } + +FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText_String_t word) +{ + const auto model = reinterpret_cast(handle); + int64_t dimensions = model->getDimension(); + auto vec = new fasttext::Vector(std::move(model->getWordVector(std::string_view(word.data, word.size)))); + + return FastText_FloatVector_t{ + vec->data(), + (void *)vec, + (size_t)vec->size(), + }; +} + +// FastText_FloatVector_t FastText_Sentencevec(const FastText_Handle_t handle, FastText_String_t sentence) +// { +// const auto model = reinterpret_cast(handle); + +// membuf sbuf(sentence); +// std::istream in(&sbuf); + +// auto vec = new fasttext::Vector(model->getDimension()); +// model->getSentenceVector(in, *vec); +// FREE_STRING(sentence); + +// return FastText_FloatVector_t{ +// vec->data(), +// (void *)vec, +// (size_t)vec->size(), +// }; +// } + +void FastText_FreeFloatVector(FastText_FloatVector_t vector) +{ + auto vec = reinterpret_cast(vector.handle); + delete vec; +} + +END_EXTERN_C() diff --git a/cbits.h b/predictions.h similarity index 59% rename from cbits.h rename to predictions.h index 9cf6fef..406ba4d 100644 --- a/cbits.h +++ b/predictions.h @@ -3,6 +3,18 @@ #include #include +#define LABEL_PREFIX ("__label__") +#define LABEL_PREFIX_SIZE (sizeof(LABEL_PREFIX) - 1) + +#define FREE_STRING(str) \ + do \ + { \ + if (str.data != nullptr) \ + free(str.data); \ + str.data = nullptr; \ + str.size = 0; \ + } while (0) + #ifdef __cplusplus #define BEGIN_EXTERN_C() \ extern "C" \ @@ -38,12 +50,6 @@ typedef struct FastText_String_t lang; } FastText_PredictItem_t; -typedef struct -{ - size_t size; - unsigned char *data; -} FastText_Predict_t; - typedef struct { enum @@ -63,11 +69,7 @@ void FastText_DeleteHandle(const FastText_Handle_t handle); size_t FastText_Predict(const FastText_Handle_t handle, FastText_String_t query, uint32_t k, float threshold, FastText_PredictItem_t *const value); FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText_String_t word); -FastText_FloatVector_t FastText_Sentencevec(const FastText_Handle_t handle, FastText_String_t sentance); -// FastText_Predict_t FastText_Analogy(const FastText_Handle_t handle, FastText_String_t word1, FastText_String_t word2, -// FastText_String_t word3, uint32_t k); void FastText_FreeFloatVector(FastText_FloatVector_t vector); -void FastText_FreePredict(FastText_Predict_t predict); END_EXTERN_C()