Skip to content

Commit

Permalink
[SLM] RWKV6 World Support
Browse files Browse the repository at this point in the history
  • Loading branch information
Celve committed Mar 15, 2024
1 parent 99addfd commit 00cc94a
Show file tree
Hide file tree
Showing 5 changed files with 586 additions and 0 deletions.
14 changes: 14 additions & 0 deletions python/mlc_llm/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .qwen import qwen_loader, qwen_model, qwen_quantization
from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization
from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization
from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization
from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization

ModelConfig = Any
Expand Down Expand Up @@ -292,4 +293,17 @@ class Model:
"group-quant": orion_quantization.group_quant,
},
),
"rwkv6": Model(
name="rwkv6",
model=rwkv6_model.RWKV6_ForCasualLM,
config=rwkv6_model.RWKV6Config,
source={
"huggingface-torch": rwkv6_loader.huggingface,
"huggingface-safetensor": rwkv6_loader.huggingface,
},
quantize={
"no-quant": rwkv6_quantization.no_quant,
"group-quant": rwkv6_quantization.group_quant,
},
),
}
Empty file.
72 changes: 72 additions & 0 deletions python/mlc_llm/model/rwkv6/rwkv6_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
This file specifies how MLC's RWKV6 parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from ...loader import ExternMapping
from ...quantization import Quantization
from .rwkv6_model import RWKV6_ForCasualLM, RWKV6Config


def huggingface(model_config: RWKV6Config, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : RWKV6Config
        The configuration of the RWKV6 model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = RWKV6_ForCasualLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params = model.export_tvm(  # pylint: disable=unbalanced-tuple-unpacking
        spec=model.get_default_spec()
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
        # These two weights are divided by 2**(layer // rescale_every) when
        # importing — presumably matching RWKV's layer-wise rescaling
        # convention for fp16 stability; TODO confirm against rwkv6_model.
        if model_config.rescale_every > 0:
            for name in ["feed_forward.value.weight", "attention.output.weight"]:
                mlc_name = f"model.blocks.{i}.{name}"
                hf_name = f"rwkv.blocks.{i}.{name}"
                mlc_param = named_parameters[mlc_name]

                mapping.add_mapping(
                    mlc_name,
                    [hf_name],
                    # partial() binds dtype/t now, avoiding the late-binding
                    # closure pitfall inside the loop.
                    functools.partial(
                        lambda x, dtype, t: x.astype(dtype) / (2**t),
                        dtype=mlc_param.dtype,
                        t=i // model_config.rescale_every,
                    ),
                )

    # Every remaining parameter maps 1:1; only the leading "model" prefix
    # differs from HuggingFace's "rwkv" prefix.
    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            # Replace only the FIRST occurrence (count=1): an unrestricted
            # replace() would also rewrite any later "model" substring inside
            # a parameter name.
            hf_name = mlc_name.replace("model", "rwkv", 1)
            mapping.add_mapping(
                mlc_name,
                [hf_name],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )

    return mapping
Loading

0 comments on commit 00cc94a

Please sign in to comment.