diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc
index 5577f9b87d..5a843402dc 100644
--- a/cpp/llm_chat.cc
+++ b/cpp/llm_chat.cc
@@ -262,8 +262,8 @@ struct FunctionTable {
     this->kv_cache_begin_forward_func_ = get_global_func("vm.builtin.kv_state_begin_forward");
     this->kv_cache_end_forward_func_ = get_global_func("vm.builtin.kv_state_end_forward");
     this->fkvcache_array_popn_ = get_global_func("vm.builtin.kv_state_popn");
-    // TODO(mlc-team): enable backtracing when using paged kvcache
-    this->support_backtracking_kv_ = true;
+    // Note: We use max sequence length = 1 for the RNN state for now, so disable backtracking.
+    this->support_backtracking_kv_ = this->use_kv_state == KVStateKind::kAttention;
   }
 }
 
diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py
index f4d39aa8ba..bbaaf0bf68 100644
--- a/python/mlc_llm/interface/gen_config.py
+++ b/python/mlc_llm/interface/gen_config.py
@@ -3,6 +3,8 @@
 import dataclasses
 import json
 import shutil
+import re
+import msgpack
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
@@ -74,6 +76,40 @@ def apply_defaults(self) -> None:
             logger.info("[System default] Setting %s: %s", bold(key), value)
 
 
+def txt2rwkvt(vocab: Path, out: Path) -> None:
+    """Generate tokenizer_model from an RWKV vocab txt file."""
+    idx2token = {}
+
+    with vocab.open("r", encoding="utf-8") as f:
+        lines = f.readlines()
+
+    for line in lines:
+        idx = int(line[: line.index(" ")])
+        x = eval(line[line.index(" ") : line.rindex(" ")])  # pylint: disable=eval-used
+        x = x.encode("utf-8") if isinstance(x, str) else x
+        assert isinstance(x, bytes)
+        assert len(x) == int(line[line.rindex(" ") :])
+        idx2token[idx] = x
+
+    with (out / "tokenizer_model").open("wb") as f:
+        msgpack.pack(idx2token, f)
+
+
+def json2rwkvt(vocab: Path, out: Path) -> None:
+    """Generate tokenizer_model from an RWKV vocab json file."""
+    idx2token = {}
+
+    with vocab.open("r", encoding="utf-8") as f:
+        data = json.load(f)
+    for key, value in data.items():
+        x = key.encode("utf-8") if isinstance(key, str) else key
+        assert isinstance(x, bytes)
+        idx2token[int(value)] = x
+
+    with (out / "tokenizer_model").open("wb") as f:
+        msgpack.pack(idx2token, f)
+
+
 def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements
     config: Path,
     model: Model,
@@ -133,7 +169,18 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
             logger.info("%s tokenizer config: %s. Copying to %s", FOUND, file, bold(str(dest)))
         else:
             logger.info("%s tokenizer config: %s", NOT_FOUND, file)
-    # 3.2. If we have `tokenizer.model` but not `tokenizer.json`, try convert it to
+    # 3.2. Generate `tokenizer_model` for RWKV if an `rwkv_vocab_*` file is found
+    pattern = re.compile(r"rwkv_vocab_v\d{8}\.(json|txt)")
+    for item in config.parent.iterdir():
+        if item.is_file() and pattern.match(item.name):
+            logger.info(
+                "%s RWKV vocab file: %s. Generating %s", FOUND, item, bold("tokenizer_model")
+            )
+            if item.name.endswith(".txt"):
+                txt2rwkvt(item, output)
+            else:
+                json2rwkvt(item, output)
+    # 3.3. If we have `tokenizer.model` but not `tokenizer.json`, try convert it to
     # `tokenizer.json` with `transformers`.
     tokenizer_json_file = config.parent / "tokenizer.json"
     tokenizer_model_file = config.parent / "tokenizer.model"
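
Note on the generated artifact: both `txt2rwkvt` and `json2rwkvt` serialize a dict mapping integer token id to raw token bytes via msgpack. The snippet below is a minimal sanity-check sketch for reviewers, not part of the patch; the `dist/rwkv-model` path is a hypothetical example. With msgpack-python >= 1.0, `strict_map_key=False` is required when reading the file back, because the map keys are integers rather than str/bytes.

import msgpack


def load_rwkv_tokenizer_model(path: str) -> dict:
    """Read back the msgpack-encoded id -> token-bytes table written by gen_config."""
    with open(path, "rb") as f:
        # msgpack-python >= 1.0 rejects non-str/bytes map keys unless
        # strict_map_key=False; the keys here are integer token ids.
        idx2token = msgpack.unpack(f, strict_map_key=False)
    assert all(isinstance(k, int) and isinstance(v, bytes) for k, v in idx2token.items())
    return idx2token


# Hypothetical usage, assuming gen_config wrote its output under dist/rwkv-model:
# idx2token = load_rwkv_tokenizer_model("dist/rwkv-model/tokenizer_model")

For reference, the txt vocab format that `txt2rwkvt` parses is one token per line: a numeric id, a Python string/bytes literal, and the literal's byte length, which is why the middle field is `eval`ed and the length asserted. `ast.literal_eval` would be a stricter drop-in if the vocab files were untrusted, though that is not what this patch uses.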