-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support embedding and move Vulkan as lib variant
- Loading branch information
Showing
16 changed files
with
685 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#include "DetokenizeWorker.h" | ||
#include "LlamaContext.h" | ||
|
||
DetokenizeWorker::DetokenizeWorker(const Napi::CallbackInfo &info, | ||
LlamaSessionPtr &sess, | ||
std::vector<llama_token> &tokens) | ||
: AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), | ||
_tokens(std::move(tokens)) {} | ||
|
||
void DetokenizeWorker::Execute() { | ||
const auto text = ::llama_detokenize_bpe(_sess->context(), _tokens); | ||
_text = std::move(text); | ||
} | ||
|
||
void DetokenizeWorker::OnOK() { | ||
Napi::Promise::Deferred::Resolve( | ||
Napi::String::New(Napi::AsyncWorker::Env(), _text)); | ||
} | ||
|
||
void DetokenizeWorker::OnError(const Napi::Error &err) { | ||
Napi::Promise::Deferred::Reject(err.Value()); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#include "common.hpp" | ||
#include <vector> | ||
|
||
class DetokenizeWorker : public Napi::AsyncWorker, | ||
public Napi::Promise::Deferred { | ||
public: | ||
DetokenizeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess, | ||
std::vector<llama_token> &tokens); | ||
|
||
protected: | ||
void Execute(); | ||
void OnOK(); | ||
void OnError(const Napi::Error &err); | ||
|
||
private: | ||
LlamaSessionPtr _sess; | ||
std::vector<llama_token> _tokens; | ||
std::string _text; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#include "EmbeddingWorker.h" | ||
#include "LlamaContext.h" | ||
|
||
EmbeddingWorker::EmbeddingWorker(const Napi::CallbackInfo &info, | ||
LlamaSessionPtr &sess, std::string text) | ||
: AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text) {} | ||
|
||
void EmbeddingWorker::Execute() { | ||
llama_kv_cache_clear(_sess->context()); | ||
auto tokens = ::llama_tokenize(_sess->context(), _text, true); | ||
// add SEP if not present | ||
if (tokens.empty() || tokens.back() != llama_token_sep(_sess->model())) { | ||
tokens.push_back(llama_token_sep(_sess->model())); | ||
} | ||
const int n_embd = llama_n_embd(_sess->model()); | ||
do { | ||
int ret = | ||
llama_decode(_sess->context(), | ||
llama_batch_get_one(tokens.data(), tokens.size(), 0, 0)); | ||
if (ret < 0) { | ||
SetError("Failed to inference, code: " + std::to_string(ret)); | ||
break; | ||
} | ||
const float *embd = llama_get_embeddings_seq(_sess->context(), 0); | ||
if (embd == nullptr) { | ||
SetError("Failed to get embeddings"); | ||
break; | ||
} | ||
_result.embedding.resize(n_embd); | ||
memcpy(_result.embedding.data(), embd, n_embd * sizeof(float)); | ||
} while (false); | ||
} | ||
|
||
void EmbeddingWorker::OnOK() { | ||
auto result = Napi::Object::New(Napi::AsyncWorker::Env()); | ||
auto embedding = Napi::Float32Array::New(Napi::AsyncWorker::Env(), | ||
_result.embedding.size()); | ||
memcpy(embedding.Data(), _result.embedding.data(), | ||
_result.embedding.size() * sizeof(float)); | ||
result.Set("embedding", embedding); | ||
Napi::Promise::Deferred::Resolve(result); | ||
} | ||
|
||
void EmbeddingWorker::OnError(const Napi::Error &err) { | ||
Napi::Promise::Deferred::Reject(err.Value()); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#include "common.hpp" | ||
#include <vector> | ||
|
||
struct EmbeddingResult { | ||
std::vector<float> embedding; | ||
}; | ||
|
||
class EmbeddingWorker : public Napi::AsyncWorker, | ||
public Napi::Promise::Deferred { | ||
public: | ||
EmbeddingWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess, | ||
std::string text); | ||
|
||
protected: | ||
void Execute(); | ||
void OnOK(); | ||
void OnError(const Napi::Error &err); | ||
|
||
private: | ||
LlamaSessionPtr _sess; | ||
std::string _text; | ||
EmbeddingResult _result; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#include "TokenizeWorker.h" | ||
#include "LlamaContext.h" | ||
|
||
TokenizeWorker::TokenizeWorker(const Napi::CallbackInfo &info, | ||
LlamaSessionPtr &sess, std::string text) | ||
: AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text) {} | ||
|
||
void TokenizeWorker::Execute() { | ||
const auto tokens = ::llama_tokenize(_sess->context(), _text, false); | ||
_result = {.tokens = std::move(tokens)}; | ||
} | ||
|
||
void TokenizeWorker::OnOK() { | ||
Napi::HandleScope scope(Napi::AsyncWorker::Env()); | ||
auto result = Napi::Object::New(Napi::AsyncWorker::Env()); | ||
auto tokens = | ||
Napi::Int32Array::New(Napi::AsyncWorker::Env(), _result.tokens.size()); | ||
memcpy(tokens.Data(), _result.tokens.data(), | ||
_result.tokens.size() * sizeof(llama_token)); | ||
result.Set("tokens", tokens); | ||
Napi::Promise::Deferred::Resolve(result); | ||
} | ||
|
||
void TokenizeWorker::OnError(const Napi::Error &err) { | ||
Napi::Promise::Deferred::Reject(err.Value()); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#include "common.hpp" | ||
#include <vector> | ||
|
||
struct TokenizeResult { | ||
std::vector<llama_token> tokens; | ||
}; | ||
|
||
class TokenizeWorker : public Napi::AsyncWorker, | ||
public Napi::Promise::Deferred { | ||
public: | ||
TokenizeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess, | ||
std::string text); | ||
|
||
protected: | ||
void Execute(); | ||
void OnOK(); | ||
void OnError(const Napi::Error &err); | ||
|
||
private: | ||
LlamaSessionPtr _sess; | ||
std::string _text; | ||
TokenizeResult _result; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.