Auto backend for pipeline and serve when backend is not set to pytorch explicitly (#1211)

* add draft
* update to use cfg
* fix
* enable gemma arch
* resolve comments
* add is_supported in each backend
* add ut
* fix ut
1 parent a6e8188 · commit 7dd97fd
Showing 8 changed files with 374 additions and 10 deletions.
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Literal, Optional, Union

from .messages import PytorchEngineConfig, TurbomindEngineConfig
from .utils import get_logger

logger = get_logger('lmdeploy')


def autoget_backend(model_path: str) -> Union[Literal['turbomind', 'pytorch']]:
    """Get backend type in auto backend mode.

    Args:
        model_path (str): the path of a model. It could be one of the
            following options:
            - i) A local directory path of a turbomind model which is
                converted by the `lmdeploy convert` command or downloaded
                from ii) and iii).
            - ii) The model_id of a lmdeploy-quantized model hosted inside
                a model repo on huggingface.co, such as
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo on
                huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.

    Returns:
        str: the backend type, either 'turbomind' or 'pytorch'.
    """
    from lmdeploy.pytorch.supported_models import \
        is_supported as is_supported_pytorch

    pytorch_has, turbomind_has = False, False
    try:
        from lmdeploy.turbomind.supported_models import \
            is_supported as is_supported_turbomind
        turbomind_has = is_supported_turbomind(model_path)
    except ImportError:
        logger.warning(
            'lmdeploy with the turbomind engine is not installed correctly. '
            'You may need to install lmdeploy from pypi or build it from '
            'source to use the turbomind engine.')

    pytorch_has = is_supported_pytorch(model_path)

    if not (pytorch_has or turbomind_has):
        logger.warning(f'{model_path} is not explicitly supported by '
                       'lmdeploy. Trying to run it with the pytorch engine.')
    backend = 'turbomind' if turbomind_has else 'pytorch'
    return backend


def autoget_backend_config(
        model_path: str,
        backend_config: Optional[Union[PytorchEngineConfig,
                                       TurbomindEngineConfig]] = None
) -> Union[PytorchEngineConfig, TurbomindEngineConfig]:
    """Get backend config automatically.

    Args:
        model_path (str): The input model path.
        backend_config (TurbomindEngineConfig | PytorchEngineConfig): The
            input backend config. Default to None.

    Returns:
        (PytorchEngineConfig | TurbomindEngineConfig): The auto-determined
            backend engine config.
    """
    from dataclasses import asdict

    backend = autoget_backend(model_path)
    if backend == 'pytorch':
        config = PytorchEngineConfig()
    else:
        config = TurbomindEngineConfig()
    if backend_config is not None:
        data = asdict(backend_config)
        for k, v in data.items():
            if v and hasattr(config, k):
                setattr(config, k, v)
    return config
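For context, a minimal usage sketch of these two helpers follows. It is illustrative only: the diff does not show which module the helpers live in, so the import path `lmdeploy.archs` is an assumption, and `session_len` is assumed to be a field shared by both engine configs. Running it requires lmdeploy installed and network access to huggingface.co.

# Hypothetical usage sketch; the module path below is an assumption, since the
# diff above does not name the file that defines these helpers.
from lmdeploy.archs import autoget_backend, autoget_backend_config
from lmdeploy.messages import TurbomindEngineConfig

model_path = 'internlm/internlm-chat-7b'

# Prefer turbomind when it supports the architecture, else fall back to pytorch.
backend = autoget_backend(model_path)

# User-supplied settings are copied onto whichever engine config is selected;
# `session_len` is assumed to exist on both engine configs.
engine_cfg = autoget_backend_config(model_path,
                                    TurbomindEngineConfig(session_len=4096))
print(backend, type(engine_cfg).__name__)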
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.

from transformers import AutoConfig

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

_SUPPORTED_ARCHS = dict(
    # baichuan-7b
    BaiChuanForCausalLM=False,
    # baichuan2-7b, baichuan-13b, baichuan2-13b
    BaichuanForCausalLM=True,
    # chatglm2-6b, chatglm3-6b
    ChatGLMModel=True,
    # deepseek-moe
    DeepseekForCausalLM=True,
    # falcon-7b
    FalconForCausalLM=True,
    # gemma-7b
    GemmaForCausalLM=True,
    # internlm
    InternLMForCausalLM=True,
    # internlm2
    InternLM2ForCausalLM=True,
    # internlm-xcomposer
    InternLMXComposerForCausalLM=False,
    # internlm2-xcomposer
    InternLM2XComposerForCausalLM=False,
    # llama, llama2, alpaca, vicuna, codellama, ultracm, yi,
    # deepseek-coder, deepseek-llm
    LlamaForCausalLM=True,
    # Mistral-7B
    MistralForCausalLM=True,
    # Mixtral-8x7B
    MixtralForCausalLM=True,
    # Qwen 7B-72B, Qwen-VL-7B
    QWenLMHeadModel=False,
    # Qwen1.5 7B-72B
    Qwen2ForCausalLM=True,
)


def is_supported(model_path: str):
    """Check whether a model is supported by the pytorch engine.

    Args:
        model_path (str): the path of a model. It could be one of the
            following options:
            - i) A local directory path of a turbomind model which is
                converted by the `lmdeploy convert` command or downloaded
                from ii) and iii).
            - ii) The model_id of a lmdeploy-quantized model hosted inside
                a model repo on huggingface.co, such as
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo on
                huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.

    Returns:
        support_by_torch (bool): Whether the input model is supported by the
            pytorch engine.
    """  # noqa: E501
    import os

    support_by_torch = False

    triton_model_path = os.path.join(model_path, 'triton_models')
    if os.path.exists(triton_model_path):
        logger.warning(f'{model_path} seems to be a turbomind workspace, '
                       'which can only be run with the turbomind engine.')
    else:
        cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        if hasattr(cfg, 'architectures'):
            arch = cfg.architectures[0]
        elif hasattr(cfg,
                     'auto_map') and 'AutoModelForCausalLM' in cfg.auto_map:
            arch = cfg.auto_map['AutoModelForCausalLM'].split('.')[-1]
        else:
            raise RuntimeError(
                f'Could not find model architecture from config: {cfg}')

        if arch in _SUPPORTED_ARCHS:
            support_by_torch = _SUPPORTED_ARCHS[arch]
            # special cases
            if arch == 'BaichuanForCausalLM':
                # baichuan-13B is not supported by the pytorch engine
                if cfg.num_attention_heads == 40 and cfg.vocab_size == 64000:
                    support_by_torch = False

    return support_by_torch
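The architecture lookup above (prefer `cfg.architectures`, otherwise fall back to the `AutoModelForCausalLM` entry in `cfg.auto_map`) decides which row of `_SUPPORTED_ARCHS` applies. Here is a small offline sketch of just that branch; `resolve_arch` is a throwaway name used for illustration, and `SimpleNamespace` stands in for a real transformers config so nothing needs to be downloaded.

# Offline illustration of the architecture-resolution branch in is_supported();
# SimpleNamespace substitutes for a transformers PretrainedConfig.
from types import SimpleNamespace


def resolve_arch(cfg):
    if hasattr(cfg, 'architectures'):
        return cfg.architectures[0]
    if hasattr(cfg, 'auto_map') and 'AutoModelForCausalLM' in cfg.auto_map:
        return cfg.auto_map['AutoModelForCausalLM'].split('.')[-1]
    raise RuntimeError(f'Could not find model architecture from config: {cfg}')


print(resolve_arch(SimpleNamespace(architectures=['LlamaForCausalLM'])))
# -> LlamaForCausalLM (enabled for the pytorch engine in the table above)
print(resolve_arch(SimpleNamespace(
    auto_map={'AutoModelForCausalLM': 'modeling_internlm.InternLMForCausalLM'})))
# -> InternLMForCausalLM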
@@ -0,0 +1,88 @@
# Copyright (c) OpenMMLab. All rights reserved.
from transformers import AutoConfig

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

_SUPPORTED_ARCHS = dict(
    # baichuan-7b
    BaiChuanForCausalLM=True,
    # baichuan2-7b, baichuan-13b, baichuan2-13b
    BaichuanForCausalLM=True,
    # chatglm2-6b, chatglm3-6b
    ChatGLMModel=False,
    # deepseek-moe
    DeepseekForCausalLM=False,
    # falcon-7b
    FalconForCausalLM=False,
    # gemma-7b
    GemmaForCausalLM=False,
    # internlm
    InternLMForCausalLM=True,
    # internlm2
    InternLM2ForCausalLM=True,
    # internlm-xcomposer
    InternLMXComposerForCausalLM=True,
    # internlm2-xcomposer
    InternLM2XComposerForCausalLM=False,
    # llama, llama2, alpaca, vicuna, codellama, ultracm, yi,
    # deepseek-coder, deepseek-llm
    LlamaForCausalLM=True,
    # Mistral-7B
    MistralForCausalLM=False,
    # Mixtral-8x7B
    MixtralForCausalLM=False,
    # Qwen 7B-72B, Qwen-VL-7B
    QWenLMHeadModel=True,
    # Qwen1.5 7B-72B
    Qwen2ForCausalLM=False)


def is_supported(model_path: str):
    """Check whether a model is supported by the turbomind engine.

    Args:
        model_path (str): the path of a model. It could be one of the
            following options:
            - i) A local directory path of a turbomind model which is
                converted by the `lmdeploy convert` command or downloaded
                from ii) and iii).
            - ii) The model_id of a lmdeploy-quantized model hosted inside
                a model repo on huggingface.co, such as
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo on
                huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.

    Returns:
        support_by_turbomind (bool): Whether the input model is supported by
            the turbomind engine.
    """  # noqa: E501
    import os

    support_by_turbomind = False
    triton_model_path = os.path.join(model_path, 'triton_models')
    if os.path.exists(triton_model_path):
        support_by_turbomind = True
    else:
        cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        if hasattr(cfg, 'architectures'):
            arch = cfg.architectures[0]
        elif hasattr(cfg,
                     'auto_map') and 'AutoModelForCausalLM' in cfg.auto_map:
            arch = cfg.auto_map['AutoModelForCausalLM'].split('.')[-1]
        else:
            raise RuntimeError(
                f'Could not find model architecture from config: {cfg}')

        if arch in _SUPPORTED_ARCHS:
            support_by_turbomind = _SUPPORTED_ARCHS[arch]
            # special cases
            if arch == 'BaichuanForCausalLM':
                num_attn_head = cfg.num_attention_heads
                if num_attn_head == 40:
                    # baichuan-13B and baichuan2-13B are not supported
                    # by turbomind
                    support_by_turbomind = False
    return support_by_turbomind
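Both tables special-case `BaichuanForCausalLM` because the 7B and 13B checkpoints share one architecture name. The sketch below contrasts the two extra checks; it is not part of the diff, the helper names are made up for illustration, and the larger vocab size used for baichuan2 is an illustrative stand-in rather than a value taken from this commit.

# Stub-config sketch of the Baichuan special cases above.
from types import SimpleNamespace


def turbomind_allows_baichuan(cfg):
    # turbomind rejects every 40-head (13B-class) Baichuan checkpoint.
    return cfg.num_attention_heads != 40


def pytorch_allows_baichuan(cfg):
    # pytorch only rejects baichuan-13B: 40 heads plus the 64000-token vocab;
    # baichuan2-13B has a larger vocabulary and stays supported.
    return not (cfg.num_attention_heads == 40 and cfg.vocab_size == 64000)


baichuan_13b = SimpleNamespace(num_attention_heads=40, vocab_size=64000)
baichuan2_13b = SimpleNamespace(num_attention_heads=40,
                                vocab_size=125696)  # illustrative, > 64000

print(turbomind_allows_baichuan(baichuan_13b),
      pytorch_allows_baichuan(baichuan_13b))    # False False
print(turbomind_allows_baichuan(baichuan2_13b),
      pytorch_allows_baichuan(baichuan2_13b))   # False True

Combined with autoget_backend() above, this means a baichuan2-13B checkpoint ends up on the pytorch engine, while 7B-class Baichuan models stay on turbomind.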