diff --git a/README.md b/README.md
index 5f703ca8e..0361d65eb 100644
--- a/README.md
+++ b/README.md
@@ -184,6 +184,14 @@ Neural Speed supports the following models:
StarCoder-1B,
diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py
index 08514dbdc..1b9412382 100644
--- a/neural_speed/__init__.py
+++ b/neural_speed/__init__.py
@@ -132,8 +132,14 @@ def init_from_bin(self, model_type, model_path, **generate_kwargs):
self.model = self.module.Model()
if "threads" not in generate_kwargs:
threads = os.getenv("OMP_NUM_THREADS")
+ import platform
+ sys_platform = platform.platform().lower()
if threads is None:
- generate_kwargs["threads"] = len(os.sched_getaffinity(0))
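+                # os.sched_getaffinity() is unavailable on Windows, so fall back to os.cpu_count()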
+ if "windows" in sys_platform:
+ cpu_count = os.cpu_count()
+ generate_kwargs["threads"] = int(cpu_count)
+ else:
+ generate_kwargs["threads"] = len(os.sched_getaffinity(0))
else:
generate_kwargs["threads"] = int(threads)
self.model.init_model(model_path, **generate_kwargs)
diff --git a/neural_speed/convert/convert_baichuan.py b/neural_speed/convert/convert_baichuan.py
index 247f90eae..bbeb3c6d0 100644
--- a/neural_speed/convert/convert_baichuan.py
+++ b/neural_speed/convert/convert_baichuan.py
@@ -158,6 +158,7 @@ def baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", hparams["intermediate_size"]))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/convert/convert_bloom.py b/neural_speed/convert/convert_bloom.py
index 3fca71b51..7e2a3f805 100644
--- a/neural_speed/convert/convert_bloom.py
+++ b/neural_speed/convert/convert_bloom.py
@@ -101,6 +101,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py
index 38f986809..0eed8e130 100644
--- a/neural_speed/convert/convert_chatglm.py
+++ b/neural_speed/convert/convert_chatglm.py
@@ -358,6 +358,7 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/convert/convert_dolly.py b/neural_speed/convert/convert_dolly.py
index f5b589f96..dc77b1c43 100644
--- a/neural_speed/convert/convert_dolly.py
+++ b/neural_speed/convert/convert_dolly.py
@@ -115,6 +115,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/convert/convert_falcon.py b/neural_speed/convert/convert_falcon.py
index 2f1ddd2b8..9d323f89d 100644
--- a/neural_speed/convert/convert_falcon.py
+++ b/neural_speed/convert/convert_falcon.py
@@ -109,6 +109,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/convert/convert_gptj.py b/neural_speed/convert/convert_gptj.py
index c7b89b2d8..2f6c8e673 100644
--- a/neural_speed/convert/convert_gptj.py
+++ b/neural_speed/convert/convert_gptj.py
@@ -101,6 +101,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/convert/convert_gptneox.py b/neural_speed/convert/convert_gptneox.py
index 3b67daa31..409cc05ba 100644
--- a/neural_speed/convert/convert_gptneox.py
+++ b/neural_speed/convert/convert_gptneox.py
@@ -115,6 +115,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/convert/convert_llama.py b/neural_speed/convert/convert_llama.py
index 8b448c95d..511608d52 100644
--- a/neural_speed/convert/convert_llama.py
+++ b/neural_speed/convert/convert_llama.py
@@ -151,6 +151,9 @@ class Params:
ffn_hidden_size: int
rms_norm_eps: float
rope_theta: float
+ rope_scale: float
+ bos_token_id: int
+ eos_token_id: int
@staticmethod
def guessed(model: 'LazyModel') -> 'Params':
@@ -180,6 +183,11 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: Path) -> 'Params':
ffn_hidden_size = config["intermediate_size"]
rms_norm_eps = config["rms_norm_eps"]
rope_theta = config["rope_theta"] if "rope_theta" in config else 10000
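+        # HF config.json exposes linear RoPE scaling as config["rope_scaling"]["factor"]; default to 1 when absent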
+ rope_scale = 1
+ if "rope_scaling" in config and config["rope_scaling"] is not None:
+ rope_scale = config["rope_scaling"]["factor"] if "factor" in config["rope_scaling"] else 1
+ bos_token_id = config["bos_token_id"]
+ eos_token_id = config["eos_token_id"]
return Params(
n_vocab=n_vocab,
@@ -191,6 +199,9 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: Path) -> 'Params':
ffn_hidden_size=ffn_hidden_size,
rms_norm_eps=rms_norm_eps,
rope_theta=rope_theta,
+ rope_scale=rope_scale,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
)
# LLaMA v2 70B params.json
@@ -206,6 +217,8 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: Path) -> 'Params':
n_head = config["n_heads"]
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
ffn_hidden_size = config["intermediate_size"]
+ bos_token_id = config["bos_token_id"]
+ eos_token_id = config["eos_token_id"]
# hack to determine LLaMA v1 vs v2 vs CodeLlama
if n_vocab == -1:
@@ -219,6 +232,8 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: Path) -> 'Params':
n_head=n_head,
n_head_kv=n_head_kv,
ffn_hidden_size=ffn_hidden_size,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
)
@staticmethod
@@ -241,11 +256,11 @@ def load(model: 'ModelPlus') -> 'Params':
class SentencePieceVocab:
- def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
+ def __init__(self, fname_tokenizer: Path, params_vocab_size: int, fname_added_tokens: Optional[Path]) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: Dict[str, int]
if fname_added_tokens is not None:
- added_tokens = json.load(open(fname_added_tokens))
+ added_tokens = json.load(open(fname_added_tokens, encoding='utf-8'))
else:
added_tokens = {}
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
@@ -260,25 +275,31 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) ->
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
+ self.params_vocab_size = params_vocab_size
def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
tokenizer = self.sentencepiece_tokenizer
- for i in range(tokenizer.vocab_size()):
- text: bytes
- if tokenizer.is_unknown(i):
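+        # iterate up to the model's declared vocab size; ids beyond the tokenizer's range get placeholder entries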
+ for i in range(self.params_vocab_size):
+ text: bytes
+ if i < tokenizer.vocab_size():
+ if tokenizer.is_unknown(i):
+ text = " \u2047 ".encode("utf-8")
+ elif tokenizer.is_control(i):
+ text = b""
+ elif tokenizer.is_byte(i):
+ piece = tokenizer.id_to_piece(i)
+ if len(piece) != 6:
+ raise Exception(f"Invalid token: {piece}")
+ byte_value = int(piece[3:-1], 16)
+ text = struct.pack("B", byte_value)
+ else:
+ text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+ score: float = tokenizer.get_score(i)
+ yield text, score
+            else:
text = " \u2047 ".encode("utf-8")
- elif tokenizer.is_control(i):
- text = b""
- elif tokenizer.is_byte(i):
- piece = tokenizer.id_to_piece(i)
- if len(piece) != 6:
- raise Exception(f"Invalid token: {piece}")
- byte_value = int(piece[3:-1], 16)
- text = struct.pack("B", byte_value)
- else:
- text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
- score: float = tokenizer.get_score(i)
- yield text, score
+ score: float = i
+ yield text, score
def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
for text in self.added_tokens_list:
@@ -1066,13 +1087,14 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None:
self.fout.write(struct.pack("f", params.rms_norm_eps))
self.fout.write(struct.pack("f", params.rope_theta))
+ self.fout.write(struct.pack("f", params.rope_scale))
# TODO, bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json
# but bos_token_id = 1 in llama.cpp
- self.fout.write(struct.pack("i", 1))
- self.fout.write(struct.pack("i", 2))
- self.fout.write(struct.pack("i", 0))
- self.fout.write(struct.pack("i", 0))
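+        # the remaining two special-token ids are written as -1 (unset)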
+ self.fout.write(struct.pack("i", params.bos_token_id))
+ self.fout.write(struct.pack("i", params.eos_token_id))
+ self.fout.write(struct.pack("i", -1))
+ self.fout.write(struct.pack("i", -1))
def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
sname = name.encode('utf-8')
@@ -1103,7 +1125,7 @@ def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
@staticmethod
def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab, file_type: NEFileType) -> None:
- check_vocab_size(params, vocab)
+ #check_vocab_size(params, vocab)
of = OutputFile(fname_out)
of.write_file_header(params, file_type)
print("Writing vocab...")
@@ -1345,7 +1367,7 @@ def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
return {name: model[name] for name in TENSORS_LIST if name in model}
-def load_vocab(path: Path) -> SentencePieceVocab:
+def load_vocab(path: Path, params_vocab_size: int) -> SentencePieceVocab:
# Be extra-friendly and accept either a file or a directory. Also, if it's
# a directory, it might be the model directory, and tokenizer.model might
# be in the parent of that.
@@ -1363,7 +1385,7 @@ def load_vocab(path: Path) -> SentencePieceVocab:
pass the directory as --vocab-dir")
added_tokens_path = path.parent / "added_tokens.json"
print(f"Loading vocab file {path}")
- return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
+ return SentencePieceVocab(path, params_vocab_size, added_tokens_path if added_tokens_path.exists() else None)
def default_outfile(model_paths: List[Path], params: Params) -> Path:
@@ -1430,15 +1452,14 @@ def main(args_in: Optional[List[str]] = None) -> None:
if args.dump:
do_dump_model(model_plus)
return
+ model = model_plus.model
+ params = Params.load(model_plus)
if model_plus.vocab is not None and args.vocab_dir is None:
vocab = model_plus.vocab
else:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
- vocab = load_vocab(vocab_dir)
-
- model = model_plus.model
- params = Params.load(model_plus)
+ vocab = load_vocab(vocab_dir, params.n_vocab)
model = do_necessary_conversions(model, params)
output_type = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, output_type)
diff --git a/neural_speed/convert/convert_mistral.py b/neural_speed/convert/convert_mistral.py
index 7bd6cfca0..71a195fcc 100644
--- a/neural_speed/convert/convert_mistral.py
+++ b/neural_speed/convert/convert_mistral.py
@@ -151,6 +151,7 @@ class Params:
ffn_hidden_size: int
rms_norm_eps: float
rope_theta: float
+ rope_scale: float
@staticmethod
def guessed(model: 'LazyModel') -> 'Params':
@@ -179,6 +180,10 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: Path) -> 'Params':
ffn_hidden_size = config["intermediate_size"]
rms_norm_eps = config["rms_norm_eps"]
rope_theta = config["rope_theta"] if "rope_theta" in config else 10000
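+        # same as convert_llama.py: pick up the linear RoPE scaling factor when the checkpoint defines one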
+ rope_scale = 1
+ if "rope_scaling" in config and config["rope_scaling"] is not None:
+ rope_scale = config["rope_scaling"]["factor"] if "factor" in config["rope_scaling"] else 1
+
return Params(
n_vocab=n_vocab,
@@ -190,6 +195,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: Path) -> 'Params':
ffn_hidden_size=ffn_hidden_size,
rms_norm_eps=rms_norm_eps,
rope_theta=rope_theta,
+ rope_scale=rope_scale,
)
# LLaMA v2 70B params.json
@@ -244,7 +250,7 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) ->
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: Dict[str, int]
if fname_added_tokens is not None:
- added_tokens = json.load(open(fname_added_tokens))
+ added_tokens = json.load(open(fname_added_tokens, encoding='utf-8'))
else:
added_tokens = {}
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
@@ -1057,6 +1063,7 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None:
self.fout.write(struct.pack("i", 0))
self.fout.write(struct.pack("f", params.rms_norm_eps))
self.fout.write(struct.pack("f", params.rope_theta))
+ self.fout.write(struct.pack("f", params.rope_scale))
self.fout.write(
struct.pack("i", 1)
diff --git a/neural_speed/convert/convert_mpt.py b/neural_speed/convert/convert_mpt.py
index 2bdf0ec4a..c441a9607 100644
--- a/neural_speed/convert/convert_mpt.py
+++ b/neural_speed/convert/convert_mpt.py
@@ -97,7 +97,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
-
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
diff --git a/neural_speed/convert/convert_opt.py b/neural_speed/convert/convert_opt.py
index 76c27c057..4f487f68c 100644
--- a/neural_speed/convert/convert_opt.py
+++ b/neural_speed/convert/convert_opt.py
@@ -108,6 +108,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/convert/convert_quantized_bloom.py b/neural_speed/convert/convert_quantized_bloom.py
index c0e88a2d7..823831efa 100644
--- a/neural_speed/convert/convert_quantized_bloom.py
+++ b/neural_speed/convert/convert_quantized_bloom.py
@@ -170,7 +170,7 @@ def bytes_to_unicode():
f.write(struct.pack("i", 0))
f.write(struct.pack("i", 0))
f.write(struct.pack("i", 0))
-fout.write(struct.pack("f", 1e-6)) # rms norm eps
+f.write(struct.pack("f", 1e-6)) # rms norm eps
f.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
f.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/convert/convert_quantized_llama.py b/neural_speed/convert/convert_quantized_llama.py
index 6b9bfd79d..5403cd3e3 100644
--- a/neural_speed/convert/convert_quantized_llama.py
+++ b/neural_speed/convert/convert_quantized_llama.py
@@ -80,6 +80,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("f", config["rms_norm_eps"]))
f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000))
+    f.write(struct.pack("f", config["rope_scale"] if "rope_scale" in config else 1.0))
# TODO, bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json
# but bos_token_id = 1 in llama.cpp
diff --git a/neural_speed/convert/convert_quantized_mistral.py b/neural_speed/convert/convert_quantized_mistral.py
index 3552320f7..e09a9ad02 100644
--- a/neural_speed/convert/convert_quantized_mistral.py
+++ b/neural_speed/convert/convert_quantized_mistral.py
@@ -83,6 +83,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("f", config["rms_norm_eps"]))
f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000))
+    f.write(struct.pack("f", config["rope_scale"] if "rope_scale" in config else 1.0))
# TODO, bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json
# but bos_token_id = 1 in llama.cpp
diff --git a/neural_speed/convert/convert_qwen.py b/neural_speed/convert/convert_qwen.py
index c559dd398..a4b9c7f4f 100644
--- a/neural_speed/convert/convert_qwen.py
+++ b/neural_speed/convert/convert_qwen.py
@@ -114,6 +114,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.special_tokens['<|endoftext|>']))
fout.write(struct.pack("i", tokenizer.special_tokens['<|endoftext|>']))
diff --git a/neural_speed/convert/convert_starcoder.py b/neural_speed/convert/convert_starcoder.py
index 327f88864..932be3f8e 100644
--- a/neural_speed/convert/convert_starcoder.py
+++ b/neural_speed/convert/convert_starcoder.py
@@ -112,6 +112,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c
index 3ec194a1c..509c047ec 100644
--- a/neural_speed/core/ne_layers.c
+++ b/neural_speed/core/ne_layers.c
@@ -2980,7 +2980,7 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*
struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, bool inplace, int n_keep, struct ne_tensor* cossin, int* n_padding,
- bool padding_left, float freq_base) {
+ bool padding_left, float freq_base, float freq_scale) {
NE_ASSERT(n_past >= 0 || n_keep >= 0);
NE_ASSERT(padding_left);
bool is_node = false;
@@ -3020,7 +3020,9 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
ne_scratch_load(ctx);
- ne_set_op_params(result, &freq_base, sizeof(freq_base));
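+  // pack both RoPE parameters into op_params: params[0] = freq_base, params[1] = freq_scale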
+ float params[] = {freq_base, freq_scale};
+  ne_set_op_params(result, &params, sizeof(params));
+
result->op = NE_OP_ROPE;
result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL;
result->src0 = a;
@@ -3031,18 +3033,20 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
}
struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
- int prompt_size, float freq_base) {
- return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base);
+ int prompt_size, float freq_base, float freq_scale) {
+ return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale);
}
struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
- int prompt_size, float freq_base) {
- return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base);
+ int prompt_size, float freq_base, float freq_scale) {
+ return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale);
}
struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode,
- int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base) {
- return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base);
+ int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base,
+ float freq_scale) {
+ return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
+ freq_scale);
}
// ne_rope_back
@@ -3078,13 +3082,16 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int
}
struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
- int prompt_size, int* n_padding, float freq_base) {
- return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base);
+ int prompt_size, int* n_padding, float freq_base, float freq_scale) {
+ return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base,
+ freq_scale);
}
struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
- int mode, int prompt_size, int* n_padding, float freq_base) {
- return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base);
+ int mode, int prompt_size, int* n_padding, float freq_base,
+ float freq_scale) {
+ return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base,
+ freq_scale);
}
// ne_alibi
@@ -7868,9 +7875,8 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
NE_ASSERT(src1->type == NE_TYPE_I32);
NE_ASSERT(ne_nelements(src1) == 5 + bs); // 5 + bs params
- float freq_base = 10000.0f;
- memcpy(&freq_base, dst->op_params, sizeof(float));
- static const float freq_scale = 1.0f;
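+  // op_params carries {freq_base, freq_scale}; the stored scale is the RoPE factor, so invert it before scaling positions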
+ const float freq_base = ((float*)(dst->op_params))[0];
+ const float freq_scale = 1 / ((float*)(dst->op_params))[1];
const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX];
const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX];
@@ -8044,7 +8050,10 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
// row index used to determine which thread to use
int ir = 0;
- const float theta_scale = powf(10000.0, -2.0f / n_dims);
+ const float freq_base = ((float*)(dst->op_params))[0];
+ const float freq_scale = 1 / ((float*)(dst->op_params))[1];
+
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
const bool skip = mode & 1;
const bool is_neox = mode & 2;
@@ -8054,7 +8063,7 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
NE_ASSERT(("shift RoPE is only implemented for the vanilla mode", !is_shift || !(is_glm || is_neox || skip)));
if (is_shift) {
- float theta = n_past;
+ float theta = n_past * freq_scale;
ne_fp16_t* cossin = (dst->opt[0] != NULL) ? dst->opt[0]->data : NULL;
if (cossin == NULL) {
cossin = malloc(ne0 * sizeof(ne_fp16_t));
@@ -8099,7 +8108,7 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
if (ir++ < ir0) continue;
if (ir > ir1) break;
- float theta = (float)p;
+ float theta = freq_scale * (float)p;
if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@@ -8172,12 +8181,13 @@ static void ne_compute_forward_rope_bestla(const struct ne_compute_params* param
const int head_num = dst->ne[2];
const int seq_len = dst->ne[1];
const int head_size = dst->ne[0];
-
+ const float freq_base = ((float*)(dst->op_params))[0];
+ const float freq_scale = 1 / ((float*)(dst->op_params))[1];
if (is_shift) {
ne_fp16_t* cossin = (dst->opt[0] != NULL) ? dst->opt[0]->data : NULL;
if (cossin == NULL) {
- float theta = n_past;
- const float theta_scale = powf(10000.0, -2.0f / n_dims);
+ float theta = n_past * freq_scale;
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
cossin = malloc(head_size * sizeof(ne_fp16_t));
for (int i0 = 0; i0 < head_size; i0 += 2) {
cossin[i0 + 0] = NE_FP32_TO_FP16(cosf(theta));
@@ -10017,7 +10027,7 @@ static void ne_compute_backward(struct ne_context* ctx, struct ne_tensor* tensor
const int n_dims = ((int32_t*)src1->data)[1];
const int mode = ((int32_t*)src1->data)[2];
src0->grad =
- ne_add_impl(ctx, src0->grad, ne_rope(ctx, tensor->grad, n_past, n_dims, mode, 0, 10000.0), inplace);
+ ne_add_impl(ctx, src0->grad, ne_rope(ctx, tensor->grad, n_past, n_dims, mode, 0, 10000.0, 1.0), inplace);
}
if (src1->grad) {
// noop
diff --git a/neural_speed/core/ne_layers.h b/neural_speed/core/ne_layers.h
index 88428b51e..d8bbf02d8 100644
--- a/neural_speed/core/ne_layers.h
+++ b/neural_speed/core/ne_layers.h
@@ -403,29 +403,30 @@ NE_API struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_t
// if mode & 4 == 1, especially for glm
// TODO: avoid creating a new tensor every time
NE_API struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
- int prompt_size, float freq_base);
+ int prompt_size, float freq_base, float freq_scale);
// in-place, returns view(a)
NE_API struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
- int prompt_size, float freq_base);
+ int prompt_size, float freq_base, float freq_scale);
// shift all tokens by a given p (n_shift)
// Optionally give a 1d tensor of precomputed interleaved cos/sin value of n_shift*scale^k for k \in [0, n_dims)
NE_API struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims,
int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
- float freq_base);
+ float freq_base, float freq_scale);
// rotary position embedding backward, i.e compute dx from dy
// a - dy
NE_API struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode);
NE_API struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
- int mode, int prompt_size, int* n_padding, float freq_base);
+ int mode, int prompt_size, int* n_padding, float freq_base,
+ float freq_scale);
// in-place, returns view(a)
NE_API struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past,
int n_dims, int mode, int prompt_size, int* n_padding,
- float freq_base);
+ float freq_base, float freq_scale);
// alibi position embedding
// in-place, returns view(a)
diff --git a/neural_speed/models/chatglm/chatglm.cpp b/neural_speed/models/chatglm/chatglm.cpp
index f7f927af4..69c973b39 100644
--- a/neural_speed/models/chatglm/chatglm.cpp
+++ b/neural_speed/models/chatglm/chatglm.cpp
@@ -137,14 +137,14 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
ne_set_name(query_layer, "query_layer");
query_layer = ne_rope_with_padding_inplace(ctx0, query_layer, n_past, rope_dim, 4, first_tokens_size,
- n_padding.data(), hparams.freq_base);
+ n_padding.data(), hparams.freq_base, hparams.freq_scale);
query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [bs, heads, qlen, head_size]
ne_tensor* key_layer =
ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size, 3 * head_size * ne_element_size(cur),
cur->nb[1], cur->nb[1] * qlen, head_size * ne_element_size(cur)); // [bs, qlen, heads, head_size]
key_layer = ne_rope_with_padding_inplace(ctx0, key_layer, n_past, rope_dim, 4, first_tokens_size,
- n_padding.data(), hparams.freq_base);
+ n_padding.data(), hparams.freq_base, hparams.freq_scale);
ne_tensor* value_layer = ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size,
3 * head_size * ne_element_size(cur), cur->nb[1], cur->nb[1] * qlen,
diff --git a/neural_speed/models/chatglm/chatglm2.cpp b/neural_speed/models/chatglm/chatglm2.cpp
index 560452a15..59d1a39e1 100644
--- a/neural_speed/models/chatglm/chatglm2.cpp
+++ b/neural_speed/models/chatglm/chatglm2.cpp
@@ -146,14 +146,15 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
0); // [N, heads, head_size]
ne_set_name(query_layer, "query_layer");
- query_layer = ne_rope_inplace(ctx0, query_layer, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base);
+ query_layer = ne_rope_inplace(ctx0, query_layer, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base,
+ hparams.freq_scale);
struct ne_tensor* key_layer =
ne_view_3d(ctx0, cur, head_size, num_kv_heads, N, head_size * ne_element_size(cur), cur->nb[1],
hidden_size * ne_element_size(cur)); // [N, kv_heads, head_size]
ne_set_name(key_layer, "key_layer");
key_layer = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K
- ctx0, key_layer, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base);
+ ctx0, key_layer, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale);
struct ne_tensor* value_layer =
ne_view_3d(ctx0, cur, head_size, num_kv_heads, N, head_size * ne_element_size(cur), cur->nb[1],
@@ -198,7 +199,8 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N
// in a single eval execution
if (N == 1) cossin_cache = kv_self.cossin;
- key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
+ key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
+ hparams.freq_scale);
key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3); // perm back
}
@@ -253,7 +255,8 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N
// in a single eval execution
if (N == 1) cossin_cache = kv_self.cossin;
- key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
+ key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
+ hparams.freq_scale);
}
value_layer =
ne_view_3d(ctx0, model.layers[il].v_cache, // tensor
diff --git a/neural_speed/models/falcon/falcon.cpp b/neural_speed/models/falcon/falcon.cpp
index 6b624e4bb..f45dabdb6 100644
--- a/neural_speed/models/falcon/falcon.cpp
+++ b/neural_speed/models/falcon/falcon.cpp
@@ -162,8 +162,8 @@ static bool falcon_model_eval_internal(model_context* ctx, const model_input* in
fused_qkv_row_nb, (n_embd + n_head_kv * head_dim) * ne_element_size(cur));
// using mode = 2 for neox mode
- Qcur = ne_rope_inplace(ctx0, Qcur, n_past, head_dim, 2, 0, hparams.freq_base);
- Kcur = ne_rope_inplace(ctx0, Kcur, n_past, head_dim, 2, 0, hparams.freq_base);
+ Qcur = ne_rope_inplace(ctx0, Qcur, n_past, head_dim, 2, 0, hparams.freq_base, hparams.freq_scale);
+ Kcur = ne_rope_inplace(ctx0, Kcur, n_past, head_dim, 2, 0, hparams.freq_base, hparams.freq_scale);
// self-attention
    const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
diff --git a/neural_speed/models/gptj/gptj.cpp b/neural_speed/models/gptj/gptj.cpp
index 1dde4ca38..25914885e 100644
--- a/neural_speed/models/gptj/gptj.cpp
+++ b/neural_speed/models/gptj/gptj.cpp
@@ -186,9 +186,10 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head, N, batch_size);
Vcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
}
- Qcur = ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base);
+ Qcur =
+ ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale);
Kcur = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K
- ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base);
+ ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale);
ne_set_name(Qcur, "Qcur");
ne_set_name(Kcur, "Kcur");
ne_set_name(Vcur, "Vcur");
@@ -292,8 +293,11 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
struct ne_tensor* cossin_cache = nullptr;
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N
// in a single eval execution
- if (N == 1) cossin_cache = kv_self.cossin;
- K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
+ if (N == 1) {
+ cossin_cache = kv_self.cossin;
+ }
+ K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
+ hparams.freq_scale);
}
const auto v_size = kv_cache_info.v_bytes;
V = ne_view_4d(ctx0, kv_self.v, // tensor
@@ -320,8 +324,11 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
struct ne_tensor* cossin_cache = nullptr;
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in
// a single eval execution
- if (N == 1) cossin_cache = kv_self.cossin;
- K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
+ if (N == 1) {
+ cossin_cache = kv_self.cossin;
+ }
+ K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
+ hparams.freq_scale);
K = ne_permute(ctx0, K, 0, 2, 1, 3);
}
} else {
@@ -331,8 +338,11 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
struct ne_tensor* cossin_cache = nullptr;
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in
// a single eval execution
- if (N == 1) cossin_cache = kv_self.cossin;
- K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
+ if (N == 1) {
+ cossin_cache = kv_self.cossin;
+ }
+ K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
+ hparams.freq_scale);
K = ne_permute(ctx0, K, 0, 2, 1, 3);
}
diff --git a/neural_speed/models/gptneox/gptneox.cpp b/neural_speed/models/gptneox/gptneox.cpp
index 80ec83d7a..4652e4664 100644
--- a/neural_speed/models/gptneox/gptneox.cpp
+++ b/neural_speed/models/gptneox/gptneox.cpp
@@ -188,9 +188,9 @@ static bool gptneox_model_eval_internal(model_context* ctx, const model_input* i
// using mode = 2 for GPT-NeoX mode
Qcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0,
- hparams.freq_base);
+ hparams.freq_base, hparams.freq_scale);
Kcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Kcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0,
- hparams.freq_base);
+ hparams.freq_base, hparams.freq_scale);
    const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
// store key and value to memory
if (!run_mha_reordered) {
diff --git a/neural_speed/models/llama/llama.cpp b/neural_speed/models/llama/llama.cpp
index 39cda575b..65673671f 100644
--- a/neural_speed/models/llama/llama.cpp
+++ b/neural_speed/models/llama/llama.cpp
@@ -187,10 +187,11 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
Kcur = ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head_kv, N);
Vcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
}
- Qcur = ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base);
+ Qcur =
+ ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale);
ne_set_name(Qcur, "Qcur");
Kcur = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K
- ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base);
+ ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale);
ne_set_name(Kcur, "Kcur");
Vcur = ne_transpose(ctx0, ne_reshape_2d(ctx0, Vcur, head_size * n_head_kv, N));
ne_set_name(Vcur, "Vcur");
@@ -220,8 +221,11 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
struct ne_tensor* cossin_cache = nullptr;
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in
// a single eval execution
- if (N == 1) cossin_cache = kv_self.cossin;
- K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
+ if (N == 1) {
+ cossin_cache = kv_self.cossin;
+ }
+ K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
+ hparams.freq_scale);
}
K = ne_permute(ctx0, K, 0, 2, 1, 3);
ne_set_name(K, "K");
@@ -301,7 +305,8 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in
// a single eval execution
if (N == 1) cossin_cache = kv_self.cossin;
- K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
+ K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base,
+ hparams.freq_scale);
}
ne_set_name(K, "K");
diff --git a/neural_speed/models/mistral/CMakeLists.txt b/neural_speed/models/mistral/CMakeLists.txt
deleted file mode 100644
index dd894d1a2..000000000
--- a/neural_speed/models/mistral/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) 2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set(TARGET mistral)
-add_library_w_warning(${TARGET} mistral.cpp mistral_utils.cpp ${MODEL_UTILS_SOURCE})
-target_compile_features(${TARGET} PUBLIC cxx_std_11) # don't bump
-set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
-target_link_libraries(${TARGET} PUBLIC ne_layers bestla::bestla)
diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h
index 30366104e..10b0c24cb 100644
--- a/neural_speed/models/model_utils/model_files.h
+++ b/neural_speed/models/model_utils/model_files.h
@@ -1043,6 +1043,7 @@ struct model_file_loader {
file.read_raw(&hparams.rms_norm_eps, sizeof(float));
file.read_raw(&hparams.freq_base, sizeof(float));
+ file.read_raw(&hparams.freq_scale, sizeof(float));
}
void read_ne_vocab() {
@@ -1162,6 +1163,7 @@ struct model_file_saver {
file.write_raw(&hparams.rms_norm_eps, sizeof(float));
file.write_raw(&hparams.freq_base, sizeof(float));
+ file.write_raw(&hparams.freq_scale, sizeof(float));
}
void write_vocab() {
if (any_file_loader->file_version == MODEL_FILE_VERSION_NE) {
diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h
index 3c05b168b..3ec457562 100644
--- a/neural_speed/models/model_utils/model_types.h
+++ b/neural_speed/models/model_utils/model_types.h
@@ -126,7 +126,8 @@ struct model_hparams {
uint32_t word_embed_proj_dim = 0; // for opt
bool do_layer_norm_before = false; // for opt
float rms_norm_eps = 1e-6f; // rms norm epsilon
- float freq_base = 10000.0f;
+ float freq_base = 10000.0f; // rope theta
+ float freq_scale = 1.0f; // rope scale factor
// ChatGLM-2
int32_t multi_query_group_num = 0;
diff --git a/neural_speed/models/model_utils/model_utils.cpp b/neural_speed/models/model_utils/model_utils.cpp
index 63d84907f..08a83fc29 100644
--- a/neural_speed/models/model_utils/model_utils.cpp
+++ b/neural_speed/models/model_utils/model_utils.cpp
@@ -148,11 +148,12 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c
const auto cossin_dtype = wtype == NE_TYPE_BTLA ? NE_TYPE_F16 : wtype;
cache.cossin = ne_new_tensor_1d(cache.ctx, cossin_dtype, head_size, NE_SIZE_CALC);
ne_set_name(cache.cossin, "cossin(-1)");
- float theta = -1;
+ float freq_base = hparams.freq_base;
+ float theta = -1 * hparams.freq_scale;
float theta_scale = (model != nullptr && model->arch == MODEL_CHATGLM2)
- ? std::pow(10000.f, -2.0f / (head_size / 2)) // chatglm2 has their DIM_SCALE of 2
- : hparams.n_rot > 0 ? std::pow(10000.f, -2.0f / hparams.n_rot)
- : std::pow(10000.f, -2.0f / head_size);
+ ? std::pow(freq_base, -2.0f / (head_size / 2)) // chatglm2 has their DIM_SCALE of 2
+ : hparams.n_rot > 0 ? std::pow(freq_base, -2.0f / hparams.n_rot)
+ : std::pow(freq_base, -2.0f / head_size);
if (cossin_dtype == NE_TYPE_F16) {
    const auto data = reinterpret_cast<ne_fp16_t*>(cache.cossin->data);
for (int i = 0; i < head_size; i += 2) {
diff --git a/neural_speed/models/qwen/qwen.cpp b/neural_speed/models/qwen/qwen.cpp
index 28ce4bcd6..837c09021 100644
--- a/neural_speed/models/qwen/qwen.cpp
+++ b/neural_speed/models/qwen/qwen.cpp
@@ -180,9 +180,9 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
fused_qkv_row_nb, 2 * sizeof(float) * n_embd));
// using mode = 2 for GPT-NeoX mode
- Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0, hparams.freq_base);
+ Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
ne_set_name(Qcur, "Qcur");
- Kcur = ne_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0, hparams.freq_base);
+ Kcur = ne_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
ne_set_name(Kcur, "kcur");
    const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
// store key and value to memory
diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh
index 42d835e23..e5168ddda 100644
--- a/tests/model-test/cpp_graph_inference.sh
+++ b/tests/model-test/cpp_graph_inference.sh
@@ -247,7 +247,7 @@ function main() {
infer_cmd="./build/bin/run_qwen"
elif [[ "${model}" == "magicoder" ]]; then
quant_script="./build/bin/quant_llama"
- convert_script="${convert_script}/convert_bmagicoder.py"
+ convert_script="${convert_script}/convert_llama.py"
infer_cmd="./build/bin/run_llama"
elif [[ "${model}" == "whisper" ]]; then
quant_script="./build/bin/quant_whisper"