Skip to content

Commit

Permalink
[SLM]: Support for rwkv tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Celve committed Mar 15, 2024
1 parent 09fe1bc commit 99addfd
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
4 changes: 2 additions & 2 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ struct FunctionTable {
this->kv_cache_begin_forward_func_ = get_global_func("vm.builtin.kv_state_begin_forward");
this->kv_cache_end_forward_func_ = get_global_func("vm.builtin.kv_state_end_forward");
this->fkvcache_array_popn_ = get_global_func("vm.builtin.kv_state_popn");
// TODO(mlc-team): enable backtracing when using paged kvcache
this->support_backtracking_kv_ = true;
// note: We use max sequence length = 1 for RNN state for now, so disable back tracking
this->support_backtracking_kv_ = this->use_kv_state == KVStateKind::kAttention;
}
}

Expand Down
49 changes: 48 additions & 1 deletion python/mlc_llm/interface/gen_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import ast
import dataclasses
import json
import re
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional

import msgpack

Expand Down Expand Up @@ -74,6 +76,40 @@ def apply_defaults(self) -> None:
logger.info("[System default] Setting %s: %s", bold(key), value)


def txt2rwkvt(vocab: Path, out: Path) -> None:
    """Generate ``tokenizer_model`` from an RWKV plain-text vocab file.

    Each vocab line has the form ``<idx> <token-literal> <byte-len>``, where
    ``<token-literal>`` is a Python str/bytes literal. The resulting
    ``{idx: token_bytes}`` mapping is msgpack-serialized to
    ``out / "tokenizer_model"``.

    Parameters
    ----------
    vocab : Path
        Path to the ``rwkv_vocab_*.txt`` file.

    out : Path
        Output directory where ``tokenizer_model`` is written.
    """
    idx2token = {}

    with vocab.open("r", encoding="utf-8") as f:
        # Stream line-by-line instead of materializing the whole file.
        for line in f:
            idx = int(line[: line.index(" ")])
            # The token field is a Python literal (str or bytes). Parse it
            # with ast.literal_eval rather than eval: literal_eval accepts
            # only literals, so a malicious vocab file cannot execute code.
            token = ast.literal_eval(line[line.index(" ") : line.rindex(" ")])
            token = token.encode("utf-8") if isinstance(token, str) else token
            assert isinstance(token, bytes)
            # The trailing field records the token's byte length; verify it.
            assert len(token) == int(line[line.rindex(" ") :])
            idx2token[idx] = token

    with (out / "tokenizer_model").open("wb") as f:
        msgpack.pack(idx2token, f)


def json2rwkvt(vocab: Path, out: Path) -> None:
    """Generate tokenizer_model from RWKV vocab file."""
    with vocab.open("r", encoding="utf-8") as f:
        mapping = json.load(f)

    idx2token = {}
    # JSON maps token -> index; invert it into {index: token_bytes}.
    for token, index in mapping.items():
        encoded = token.encode("utf-8") if isinstance(token, str) else token
        assert isinstance(encoded, bytes)
        idx2token[int(index)] = encoded

    with (out / "tokenizer_model").open("wb") as f:
        msgpack.pack(idx2token, f)


def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
config: Path,
model: Model,
Expand Down Expand Up @@ -133,7 +169,18 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
logger.info("%s tokenizer config: %s. Copying to %s", FOUND, file, bold(str(dest)))
else:
logger.info("%s tokenizer config: %s", NOT_FOUND, file)
# 3.2. If we have `tokenizer.model` but not `tokenizer.json`, try convert it to
# 3.2. Generate `tokenizer_model` for rwkv if `rwkv_vocab_.*` is found
pattern = re.compile(r"rwkv_vocab_v\d{8}\.(json|txt)")
for item in config.parent.iterdir():
if item.is_file() and pattern.match(item.name):
logger.info(
    "%s RWKV vocab file: %s. Generating %s", FOUND, item, bold("tokenizer_model")
)
if item.name.endswith(".txt"):
txt2rwkvt(item, output)
else:
json2rwkvt(item, output)
# 3.3. If we have `tokenizer.model` but not `tokenizer.json`, try convert it to
# `tokenizer.json` with `transformers`.
tokenizer_json_file = config.parent / "tokenizer.json"
tokenizer_model_file = config.parent / "tokenizer.model"
Expand Down

0 comments on commit 99addfd

Please sign in to comment.