Skip to content

Commit

Permalink
feat: implement vocab_only mode for context (#67)
Browse files Browse the repository at this point in the history
* feat: implement vocab_only mode for context

* chore(ts): update test

* fix: typo
  • Loading branch information
jhen0409 authored Jan 2, 2025
1 parent 18b810a commit 5c91e0d
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 17 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ project (llama-node)
set(CMAKE_CXX_STANDARD 17)

execute_process(COMMAND
git apply ${CMAKE_CURRENT_SOURCE_DIR}/scripts/ggml-cpu-CMakeLists.txt.patch
git apply ${CMAKE_CURRENT_SOURCE_DIR}/scripts/llama.cpp.patch
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
)

Expand Down
1 change: 1 addition & 0 deletions lib/binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export type LlamaModelOptions = {
n_gpu_layers?: number
use_mlock?: boolean
use_mmap?: boolean
vocab_only?: boolean
}

export type LlamaCompletionOptions = {
Expand Down
13 changes: 0 additions & 13 deletions scripts/ggml-cpu-CMakeLists.txt.patch

This file was deleted.

37 changes: 37 additions & 0 deletions scripts/llama.cpp.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
diff --git a/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt b/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
index 683b90af..e1bf104c 100644
--- a/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
@@ -80,7 +80,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
message(STATUS "ARM detected")

if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang")
- message(FATAL_ERROR "MSVC is not supported for ARM, use clang")
+ list(APPEND ARCH_FLAGS /arch:armv8.7)
else()
check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
diff --git a/src/llama.cpp/common/common.h b/src/llama.cpp/common/common.h
index 1d2bd932..b5007c66 100644
--- a/src/llama.cpp/common/common.h
+++ b/src/llama.cpp/common/common.h
@@ -183,6 +183,7 @@ struct common_params_vocoder {
};

struct common_params {
+ bool vocab_only = false;
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
diff --git a/src/llama.cpp/common/common.cpp b/src/llama.cpp/common/common.cpp
index 20be9291..1bedc55d 100644
--- a/src/llama.cpp/common/common.cpp
+++ b/src/llama.cpp/common/common.cpp
@@ -1017,6 +1017,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
+ mparams.vocab_only = params.vocab_only;
mparams.rpc_servers = params.rpc_servers.c_str();
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
5 changes: 5 additions & 0 deletions src/LlamaContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
Napi::TypeError::New(env, "Model is required").ThrowAsJavaScriptException();
}

params.vocab_only = get_option<bool>(options, "vocab_only", false);
if (params.vocab_only) {
params.warmup = false;
}

params.n_ctx = get_option<int32_t>(options, "n_ctx", 512);
params.n_batch = get_option<int32_t>(options, "n_batch", 2048);
params.embedding = get_option<bool>(options, "embedding", false);
Expand Down
60 changes: 58 additions & 2 deletions test/__snapshots__/index.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ exports[`tokeneize 3`] = `
"
`;

exports[`work fine 1`] = `
exports[`works fine 1`] = `
{
"text": " swochadoorter scientific WindowsCa occupiedrå alta",
"timings": "Timings: (8) keys",
Expand All @@ -423,7 +423,63 @@ exports[`work fine 1`] = `
}
`;

exports[`work fine: model info 1`] = `
exports[`works fine with vocab_only: empty result 1`] = `
{
"text": "",
"timings": {
"predicted_ms": 0,
"predicted_n": 1,
"predicted_per_second": Infinity,
"predicted_per_token_ms": 0,
"prompt_ms": 0,
"prompt_n": 1,
"prompt_per_second": Infinity,
"prompt_per_token_ms": 0,
},
"tokens_evaluated": 0,
"tokens_predicted": 0,
"truncated": false,
}
`;

exports[`works fine with vocab_only: model info 1`] = `
{
"desc": "llama ?B all F32",
"isChatTemplateSupported": false,
"metadata": {
"general.architecture": "llama",
"general.file_type": "1",
"general.name": "LLaMA v2",
"llama.attention.head_count": "2",
"llama.attention.head_count_kv": "2",
"llama.attention.layer_norm_rms_epsilon": "0.000010",
"llama.block_count": "1",
"llama.context_length": "4096",
"llama.embedding_length": "8",
"llama.feed_forward_length": "32",
"llama.rope.dimension_count": "4",
"tokenizer.ggml.bos_token_id": "1",
"tokenizer.ggml.eos_token_id": "2",
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.unknown_token_id": "0",
},
"nParams": 513048,
"size": 1026144,
}
`;

exports[`works fine with vocab_only: tokenize 1`] = `
{
"tokens": Int32Array [
9038,
2501,
263,
931,
],
}
`;

exports[`works fine: model info 1`] = `
{
"desc": "llama ?B F16",
"isChatTemplateSupported": false,
Expand Down
9 changes: 8 additions & 1 deletion test/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import path from 'path'
import waitForExpect from 'wait-for-expect'
import { loadModel } from '../lib'

it('work fine', async () => {
it('works fine', async () => {
let tokens = ''
const model = await loadModel({ model: path.resolve(__dirname, './tiny-random-llama.gguf') })
const info = model.getModelInfo()
Expand Down Expand Up @@ -30,6 +30,13 @@ it('work fine', async () => {
await model.release()
})

it('works fine with vocab_only', async () => {
const model = await loadModel({ model: path.resolve(__dirname, './tiny-random-llama.gguf'), vocab_only: true })
expect(model.getModelInfo()).toMatchSnapshot('model info')
expect(await model.tokenize('Once upon a time')).toMatchSnapshot('tokenize')
expect(await model.completion({ prompt: 'Once upon a time' })).toMatchSnapshot('empty result')
})

it('tokeneize', async () => {
const model = await loadModel({ model: path.resolve(__dirname, './tiny-random-llama.gguf') })
{
Expand Down

0 comments on commit 5c91e0d

Please sign in to comment.