feat: support embedding and move Vulkan to a lib variant
hans00 committed May 14, 2024
1 parent 25a8337 commit 5e88aa1
Showing 16 changed files with 685 additions and 7 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -77,6 +77,12 @@ file(
"src/LlamaCompletionWorker.h"
"src/LlamaContext.cpp"
"src/LlamaContext.h"
"src/TokenizeWorker.cpp"
"src/TokenizeWorker.h"
"src/DetokenizeWorker.cpp"
"src/DetokenizeWorker.h"
"src/EmbeddingWorker.cpp"
"src/EmbeddingWorker.h"
"src/LoadSessionWorker.cpp"
"src/LoadSessionWorker.h"
"src/SaveSessionWorker.cpp"
3 changes: 2 additions & 1 deletion README.md
@@ -47,7 +47,8 @@ console.log('Result:', text)

## Lib Variants

- [x] `default`: General usage, Supported GPU: Metal (macOS) and Vulkan (Linux / Windows)
- [x] `default`: General usage; no GPU support except Metal on macOS
- [x] `vulkan`: Vulkan GPU support (Windows / Linux), but may be unstable in some scenarios

## License

11 changes: 11 additions & 0 deletions lib/binding.ts
@@ -37,11 +37,22 @@ export type LlamaCompletionToken = {
token: string
}

export type TokenizeResult = {
tokens: Int32Array
}

export type EmbeddingResult = {
embedding: Float32Array
}

export interface LlamaContext {
new (options: LlamaModelOptions): LlamaContext
getSystemInfo(): string
completion(options: LlamaCompletionOptions, callback?: (token: LlamaCompletionToken) => void): Promise<LlamaCompletionResult>
stopCompletion(): void
tokenize(text: string): Promise<TokenizeResult>
detokenize(tokens: number[]): Promise<string>
embedding(text: string): Promise<EmbeddingResult>
saveSession(path: string): Promise<void>
loadSession(path: string): Promise<void>
release(): Promise<void>
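The new `tokenize`, `detokenize`, and `embedding` methods all return promises, matching the existing `completion` API. A minimal consumer sketch (`ctx` stands for an already-constructed `LlamaContext`; the import path is illustrative, not part of this commit):

```ts
import type { LlamaContext } from './lib/binding'

// Assumes `ctx` was created elsewhere via `new LlamaContext(options)`.
async function demo(ctx: LlamaContext): Promise<void> {
  const { tokens } = await ctx.tokenize('Hello world') // Int32Array
  console.log('token count:', tokens.length)

  // detokenize() takes a plain number[], not a typed array,
  // so convert the Int32Array before round-tripping.
  const text = await ctx.detokenize(Array.from(tokens))
  console.log('round trip:', text)

  const { embedding } = await ctx.embedding('Hello world') // Float32Array
  console.log('dimensions:', embedding.length)
}
```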
6 changes: 4 additions & 2 deletions scripts/build-linux.sh
@@ -5,7 +5,9 @@ set -e
# General

if [ $(uname -m) == "x86_64" ]; then
yarn clean && yarn build-native --CDLLAMA_VULKAN=1
yarn clean && yarn build-native
yarn clean && yarn build-native --CDLLAMA_VULKAN=1 --CDVARIANT=vulkan
else
yarn clean && yarn build-native --CDLLAMA_VULKAN=1 --CDVULKAN_SDK="$(realpath 'externals/arm64-Vulkan-SDK')"
yarn clean && yarn build-native
yarn clean && yarn build-native --CDLLAMA_VULKAN=1 --CDVULKAN_SDK="$(realpath 'externals/arm64-Vulkan-SDK')" --CDVARIANT=vulkan
fi
9 changes: 7 additions & 2 deletions scripts/build-windows.ps1
@@ -2,5 +2,10 @@ $ErrorActionPreference='Stop'

# General

yarn clean ; yarn build-native -a x86_64 --CDCMAKE_PREFIX_PATH=(Resolve-Path 'externals/win32-x64/Vulkan-SDK')
yarn clean ; yarn build-native -a arm64 --CDCMAKE_PREFIX_PATH=(Resolve-Path 'externals/win32-arm64/Vulkan-SDK')
yarn clean ; yarn build-native -a x86_64
yarn clean ; yarn build-native -a arm64

# Vulkan, might crash in some scenarios

yarn clean ; yarn build-native -a x86_64 --CDVULKAN_SDK=(Resolve-Path 'externals/win32-x64/Vulkan-SDK') --CDVARIANT=vulkan --CDLLAMA_VULKAN=1
yarn clean ; yarn build-native -a arm64 --CDVULKAN_SDK=(Resolve-Path 'externals/win32-arm64/Vulkan-SDK') --CDVARIANT=vulkan --CDLLAMA_VULKAN=1
22 changes: 22 additions & 0 deletions src/DetokenizeWorker.cpp
@@ -0,0 +1,22 @@
#include "DetokenizeWorker.h"
#include "LlamaContext.h"

DetokenizeWorker::DetokenizeWorker(const Napi::CallbackInfo &info,
LlamaSessionPtr &sess,
std::vector<llama_token> &tokens)
: AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess),
_tokens(std::move(tokens)) {}

void DetokenizeWorker::Execute() {
auto text = ::llama_detokenize_bpe(_sess->context(), _tokens);
// non-const so the move below actually moves instead of copying
_text = std::move(text);
}

void DetokenizeWorker::OnOK() {
Napi::Promise::Deferred::Resolve(
Napi::String::New(Napi::AsyncWorker::Env(), _text));
}

void DetokenizeWorker::OnError(const Napi::Error &err) {
Napi::Promise::Deferred::Reject(err.Value());
}
19 changes: 19 additions & 0 deletions src/DetokenizeWorker.h
@@ -0,0 +1,19 @@
#include "common.hpp"
#include <vector>

class DetokenizeWorker : public Napi::AsyncWorker,
public Napi::Promise::Deferred {
public:
DetokenizeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
std::vector<llama_token> &tokens);

protected:
void Execute();
void OnOK();
void OnError(const Napi::Error &err);

private:
LlamaSessionPtr _sess;
std::vector<llama_token> _tokens;
std::string _text;
};
46 changes: 46 additions & 0 deletions src/EmbeddingWorker.cpp
@@ -0,0 +1,46 @@
#include "EmbeddingWorker.h"
#include "LlamaContext.h"

EmbeddingWorker::EmbeddingWorker(const Napi::CallbackInfo &info,
LlamaSessionPtr &sess, std::string text)
: AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text) {}

void EmbeddingWorker::Execute() {
llama_kv_cache_clear(_sess->context());
auto tokens = ::llama_tokenize(_sess->context(), _text, true);
// add SEP if not present
if (tokens.empty() || tokens.back() != llama_token_sep(_sess->model())) {
tokens.push_back(llama_token_sep(_sess->model()));
}
const int n_embd = llama_n_embd(_sess->model());
do {
int ret =
llama_decode(_sess->context(),
llama_batch_get_one(tokens.data(), tokens.size(), 0, 0));
if (ret < 0) {
SetError("Failed to inference, code: " + std::to_string(ret));
break;
}
const float *embd = llama_get_embeddings_seq(_sess->context(), 0);
if (embd == nullptr) {
SetError("Failed to get embeddings");
break;
}
_result.embedding.resize(n_embd);
memcpy(_result.embedding.data(), embd, n_embd * sizeof(float));
} while (false);
}

void EmbeddingWorker::OnOK() {
auto result = Napi::Object::New(Napi::AsyncWorker::Env());
auto embedding = Napi::Float32Array::New(Napi::AsyncWorker::Env(),
_result.embedding.size());
memcpy(embedding.Data(), _result.embedding.data(),
_result.embedding.size() * sizeof(float));
result.Set("embedding", embedding);
Napi::Promise::Deferred::Resolve(result);
}

void EmbeddingWorker::OnError(const Napi::Error &err) {
Napi::Promise::Deferred::Reject(err.Value());
}
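The worker clears the KV cache, tokenizes the input (appending a SEP token when missing), runs a single `llama_decode`, and copies the `n_embd` floats returned by `llama_get_embeddings_seq` into the result. A typical downstream use is comparing two texts by cosine similarity of their embeddings; a sketch, assuming an already-constructed `LlamaContext`:

```ts
import type { LlamaContext } from './lib/binding'

function cosine(a: Float32Array, b: Float32Array): number {
  let dot = 0
  let normA = 0
  let normB = 0
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i]
    normA += a[i] * a[i]
    normB += b[i] * b[i]
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB))
}

async function similarity(ctx: LlamaContext, x: string, y: string): Promise<number> {
  const { embedding: ex } = await ctx.embedding(x)
  const { embedding: ey } = await ctx.embedding(y)
  return cosine(ex, ey) // 1.0 = same direction, 0 = orthogonal
}
```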
23 changes: 23 additions & 0 deletions src/EmbeddingWorker.h
@@ -0,0 +1,23 @@
#include "common.hpp"
#include <vector>

struct EmbeddingResult {
std::vector<float> embedding;
};

class EmbeddingWorker : public Napi::AsyncWorker,
public Napi::Promise::Deferred {
public:
EmbeddingWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
std::string text);

protected:
void Execute();
void OnOK();
void OnError(const Napi::Error &err);

private:
LlamaSessionPtr _sess;
std::string _text;
EmbeddingResult _result;
};
62 changes: 62 additions & 0 deletions src/LlamaContext.cpp
@@ -1,8 +1,11 @@
#include "LlamaContext.h"
#include "DetokenizeWorker.h"
#include "DisposeWorker.h"
#include "EmbeddingWorker.h"
#include "LlamaCompletionWorker.h"
#include "LoadSessionWorker.h"
#include "SaveSessionWorker.h"
#include "TokenizeWorker.h"

void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
Napi::Function func = DefineClass(
@@ -16,6 +19,13 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
InstanceMethod<&LlamaContext::StopCompletion>(
"stopCompletion",
static_cast<napi_property_attributes>(napi_enumerable)),
InstanceMethod<&LlamaContext::Tokenize>(
"tokenize", static_cast<napi_property_attributes>(napi_enumerable)),
InstanceMethod<&LlamaContext::Detokenize>(
"detokenize",
static_cast<napi_property_attributes>(napi_enumerable)),
InstanceMethod<&LlamaContext::Embedding>(
"embedding", static_cast<napi_property_attributes>(napi_enumerable)),
InstanceMethod<&LlamaContext::SaveSession>(
"saveSession",
static_cast<napi_property_attributes>(napi_enumerable)),
@@ -163,6 +173,58 @@ void LlamaContext::StopCompletion(const Napi::CallbackInfo &info) {
}
}

// tokenize(text: string): Promise<TokenizeResult>
Napi::Value LlamaContext::Tokenize(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
if (info.Length() < 1 || !info[0].IsString()) {
Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
}
if (_sess == nullptr) {
Napi::TypeError::New(env, "Context is disposed")
.ThrowAsJavaScriptException();
}
auto text = info[0].ToString().Utf8Value();
auto *worker = new TokenizeWorker(info, _sess, text);
worker->Queue();
return worker->Promise();
}

// detokenize(tokens: number[]): Promise<string>
Napi::Value LlamaContext::Detokenize(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
if (info.Length() < 1 || !info[0].IsArray()) {
Napi::TypeError::New(env, "Array expected").ThrowAsJavaScriptException();
}
if (_sess == nullptr) {
Napi::TypeError::New(env, "Context is disposed")
.ThrowAsJavaScriptException();
}
auto tokens = info[0].As<Napi::Array>();
std::vector<int32_t> token_ids;
for (size_t i = 0; i < tokens.Length(); i++) {
token_ids.push_back(tokens.Get(i).ToNumber().Int32Value());
}
auto *worker = new DetokenizeWorker(info, _sess, token_ids);
worker->Queue();
return worker->Promise();
}

// embedding(text: string): Promise<EmbeddingResult>
Napi::Value LlamaContext::Embedding(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
if (info.Length() < 1 || !info[0].IsString()) {
Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
}
if (_sess == nullptr) {
Napi::TypeError::New(env, "Context is disposed")
.ThrowAsJavaScriptException();
}
auto text = info[0].ToString().Utf8Value();
auto *worker = new EmbeddingWorker(info, _sess, text);
worker->Queue();
return worker->Promise();
}

// saveSession(path: string): Promise<void> throws error
Napi::Value LlamaContext::SaveSession(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
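Each of the three new methods validates its argument and the session before queueing a worker, throwing a `TypeError` from the synchronous part of the call. Because the call is awaited inside `try`, both synchronous throws and promise rejections land in the same `catch`; a defensive-call sketch (hypothetical wrapper, not part of this commit):

```ts
async function safeTokenize(ctx: LlamaContext, text: string) {
  try {
    return await ctx.tokenize(text)
  } catch (err) {
    // "String expected" or "Context is disposed" arrive as TypeErrors
    if (err instanceof TypeError) {
      console.warn('tokenize failed:', err.message)
      return null
    }
    throw err
  }
}
```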
3 changes: 3 additions & 0 deletions src/LlamaContext.h
@@ -11,6 +11,9 @@ class LlamaContext : public Napi::ObjectWrap<LlamaContext> {
Napi::Value GetSystemInfo(const Napi::CallbackInfo &info);
Napi::Value Completion(const Napi::CallbackInfo &info);
void StopCompletion(const Napi::CallbackInfo &info);
Napi::Value Tokenize(const Napi::CallbackInfo &info);
Napi::Value Detokenize(const Napi::CallbackInfo &info);
Napi::Value Embedding(const Napi::CallbackInfo &info);
Napi::Value SaveSession(const Napi::CallbackInfo &info);
Napi::Value LoadSession(const Napi::CallbackInfo &info);
Napi::Value Release(const Napi::CallbackInfo &info);
26 changes: 26 additions & 0 deletions src/TokenizeWorker.cpp
@@ -0,0 +1,26 @@
#include "TokenizeWorker.h"
#include "LlamaContext.h"

TokenizeWorker::TokenizeWorker(const Napi::CallbackInfo &info,
LlamaSessionPtr &sess, std::string text)
: AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text) {}

void TokenizeWorker::Execute() {
auto tokens = ::llama_tokenize(_sess->context(), _text, false);
_result = {.tokens = std::move(tokens)};
}

void TokenizeWorker::OnOK() {
Napi::HandleScope scope(Napi::AsyncWorker::Env());
auto result = Napi::Object::New(Napi::AsyncWorker::Env());
auto tokens =
Napi::Int32Array::New(Napi::AsyncWorker::Env(), _result.tokens.size());
memcpy(tokens.Data(), _result.tokens.data(),
_result.tokens.size() * sizeof(llama_token));
result.Set("tokens", tokens);
Napi::Promise::Deferred::Resolve(result);
}

void TokenizeWorker::OnError(const Napi::Error &err) {
Napi::Promise::Deferred::Reject(err.Value());
}
23 changes: 23 additions & 0 deletions src/TokenizeWorker.h
@@ -0,0 +1,23 @@
#include "common.hpp"
#include <vector>

struct TokenizeResult {
std::vector<llama_token> tokens;
};

class TokenizeWorker : public Napi::AsyncWorker,
public Napi::Promise::Deferred {
public:
TokenizeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
std::string text);

protected:
void Execute();
void OnOK();
void OnError(const Napi::Error &err);

private:
LlamaSessionPtr _sess;
std::string _text;
TokenizeResult _result;
};
5 changes: 3 additions & 2 deletions src/common.hpp
@@ -47,7 +47,8 @@ constexpr T get_option(const Napi::Object &options, const std::string &name,
class LlamaSession {
public:
LlamaSession(llama_model *model, llama_context *ctx, gpt_params params)
: model_(LlamaCppModel(model, llama_free_model)), ctx_(LlamaCppContext(ctx, llama_free)), params_(params) {
: model_(LlamaCppModel(model, llama_free_model)),
ctx_(LlamaCppContext(ctx, llama_free)), params_(params) {
tokens_.reserve(params.n_ctx);
}

@@ -57,7 +58,7 @@ class LlamaSession {

inline llama_model *model() { return model_.get(); }

inline std::vector<llama_token>* tokens_ptr() { return &tokens_; }
inline std::vector<llama_token> *tokens_ptr() { return &tokens_; }

inline void set_tokens(std::vector<llama_token> tokens) {
tokens_ = std::move(tokens);