Auto backend for pipeline and serve when backend is not set to pytorch explicitly (#1211)

* add draft

* update to use cfg

* fix

* enable gemma arch

* resolve comments

* add is_supported in each backend

* add ut

* fix ut
RunningLeon authored Mar 4, 2024
1 parent a6e8188 commit 7dd97fd
Showing 8 changed files with 374 additions and 10 deletions.
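
To make the intent of the change concrete, here is a minimal sketch of the new behavior, assuming the standard lmdeploy Python API; the model id is only an illustration.

from lmdeploy import pipeline

# No backend_config is given, so the backend ('turbomind' or 'pytorch') is now
# resolved automatically from the model's architecture instead of defaulting
# blindly to turbomind.
pipe = pipeline('internlm/internlm-chat-7b')
print(pipe(['Hi, please introduce yourself']))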
2 changes: 1 addition & 1 deletion .github/workflows/unit-test.yml
@@ -74,7 +74,7 @@ jobs:
make -j$(nproc) && make install
- name: Install lmdeploy
run: |
python3 -m pip install pynvml packaging protobuf transformers_stream_generator transformers==4.33.0
python3 -m pip install pynvml packaging protobuf transformers_stream_generator
# manually install flash attn
python3 -m pip install /root/packages/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
python3 -m pip install -r requirements.txt -r requirements/test.txt
16 changes: 14 additions & 2 deletions lmdeploy/api.py
@@ -2,6 +2,7 @@
import os
from typing import List, Literal, Optional, Union

from .archs import autoget_backend_config
from .messages import PytorchEngineConfig, TurbomindEngineConfig
from .model import ChatTemplateConfig

@@ -31,7 +32,7 @@ def pipeline(model_path: str,
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
@@ -49,8 +50,13 @@ def pipeline(model_path: str,
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
logger.setLevel(log_level)

if type(backend_config) is not PytorchEngineConfig:
# set auto backend mode
backend_config = autoget_backend_config(model_path, backend_config)
backend = 'pytorch' if type(
backend_config) is PytorchEngineConfig else 'turbomind'
logger.info(f'Using {backend} engine')
if 'tp' in kwargs:
logger.warning(
'The argument "tp" is deprecated and will be removed soon. '
@@ -101,7 +107,7 @@ def serve(model_path: str,
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to none.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
@@ -126,6 +132,12 @@

from lmdeploy.serve.openai.api_client import APIClient
from lmdeploy.serve.openai.api_server import serve

if type(backend_config) is not PytorchEngineConfig:
# set auto backend mode
backend_config = autoget_backend_config(model_path, backend_config)
backend = 'pytorch' if type(
backend_config) is PytorchEngineConfig else 'turbomind'
if 'tp' in kwargs:
tp = kwargs['tp']
kwargs.pop('tp')
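A hedged sketch of what the new branch means for callers: when the supplied config is a TurbomindEngineConfig but the model is only supported by the pytorch engine, autoget_backend_config is expected to rebuild it as a PytorchEngineConfig and carry over the overlapping fields. The model id and field values below are illustrative assumptions, not taken from this commit.

from lmdeploy import pipeline, TurbomindEngineConfig

# ChatGLMModel is listed as pytorch-only in this commit, so the turbomind
# config below should be converted to a PytorchEngineConfig with tp and
# session_len preserved.
config = TurbomindEngineConfig(tp=1, session_len=4096)
pipe = pipeline('THUDM/chatglm2-6b', backend_config=config)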
82 changes: 82 additions & 0 deletions lmdeploy/archs.py
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Literal, Optional, Union

from .messages import PytorchEngineConfig, TurbomindEngineConfig
from .utils import get_logger

logger = get_logger('lmdeploy')


def autoget_backend(model_path: str) -> Union[Literal['turbomind', 'pytorch']]:
"""Get backend type in auto backend mode.
Args:
model_path (str): the path of a model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download from
ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
Returns:
str: the backend type.
"""
from lmdeploy.pytorch.supported_models import \
is_supported as is_supported_pytorch

pytorch_has, turbomind_has = False, False
try:
from lmdeploy.turbomind.supported_models import \
is_supported as is_supported_turbomind
turbomind_has = is_supported_turbomind(model_path)
except ImportError:
logger.warning(
'Lmdeploy with turbomind engine is not installed correctly. '
'You may need to install lmdeploy from pypi or build from source '
'for turbomind engine.')

pytorch_has = is_supported_pytorch(model_path)

if not (pytorch_has or turbomind_has):
logger.warning(f'{model_path} is not explicitly supported by lmdeploy.'
f' Try to run with lmdeploy pytorch engine.')
backend = 'turbomind' if turbomind_has else 'pytorch'
return backend


def autoget_backend_config(
model_path: str,
backend_config: Optional[Union[PytorchEngineConfig,
TurbomindEngineConfig]] = None
) -> Union[PytorchEngineConfig, TurbomindEngineConfig]:
"""Get backend config automatically.
Args:
model_path (str): The input model path.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): The
input backend config. Default to None.
Returns:
(PytorchEngineConfig | TurbomindEngineConfig): The auto-determined
backend engine config.
"""
from dataclasses import asdict

backend = autoget_backend(model_path)
if backend == 'pytorch':
config = PytorchEngineConfig()
else:
config = TurbomindEngineConfig()
if backend_config is not None:
data = asdict(backend_config)
for k, v in data.items():
if v and hasattr(config, k):
setattr(config, k, v)
return config
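
A hedged usage sketch of the two helpers above; the model id is illustrative, and both calls read the Hugging Face config, so they need network access or a local copy of the model.

from lmdeploy.archs import autoget_backend, autoget_backend_config
from lmdeploy.messages import TurbomindEngineConfig

model = 'internlm/internlm-chat-7b'              # illustrative model id
backend = autoget_backend(model)                 # 'turbomind' or 'pytorch'
# Non-empty fields of the input config (here tp=2) are copied onto whichever
# config class the detected backend requires.
config = autoget_backend_config(model, TurbomindEngineConfig(tp=2))
print(backend, type(config).__name__)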
3 changes: 2 additions & 1 deletion lmdeploy/cli/cli.py
@@ -107,7 +107,8 @@ def list(args):
if engine == 'pytorch':
model_names = [
'llama', 'llama2', 'internlm', 'internlm2', 'baichuan2',
'chatglm2', 'falcon', 'yi', 'mistral', 'qwen1.5', 'gemma'
'chatglm2', 'falcon', 'yi', 'mistral', 'mixtral', 'qwen1.5',
'gemma', 'deepseek'
]
elif engine == 'turbomind':
from lmdeploy.model import MODELS
24 changes: 18 additions & 6 deletions lmdeploy/cli/serve.py
@@ -189,18 +189,24 @@ def add_parser_triton_client():
@staticmethod
def gradio(args):
"""Serve LLMs with web UI using gradio."""
from lmdeploy.archs import autoget_backend
from lmdeploy.messages import (PytorchEngineConfig,
TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.serve.gradio.app import run
if args.backend == 'pytorch':
from lmdeploy.messages import PytorchEngineConfig
backend = args.backend

if backend != 'pytorch' and ':' not in args.model_path_or_server:
# set auto backend mode
backend = autoget_backend(args.model_path_or_server)
if backend == 'pytorch':
backend_config = PytorchEngineConfig(
tp=args.tp,
model_name=args.model_name,
max_batch_size=args.max_batch_size,
cache_max_entry_count=args.cache_max_entry_count,
session_len=args.session_len)
else:
from lmdeploy.messages import TurbomindEngineConfig
backend_config = TurbomindEngineConfig(
model_name=args.model_name,
tp=args.tp,
@@ -217,16 +223,22 @@ def gradio(args):
run(args.model_path_or_server,
server_name=args.server_name,
server_port=args.server_port,
backend=args.backend,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config)

@staticmethod
def api_server(args):
"""Serve LLMs with restful api using fastapi."""
from lmdeploy.archs import autoget_backend
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.serve.openai.api_server import serve as run_api_server
if args.backend == 'pytorch':
backend = args.backend
if backend != 'pytorch':
# set auto backend mode
backend = autoget_backend(args.model_path)

if backend == 'pytorch':
from lmdeploy.messages import PytorchEngineConfig
backend_config = PytorchEngineConfig(
tp=args.tp,
@@ -250,7 +262,7 @@ def api_server(args):
meta_instruction=args.meta_instruction,
capability=args.cap)
run_api_server(args.model_path,
backend=args.backend,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
server_name=args.server_name,
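Note that the gradio path only auto-detects when the positional argument is a model path rather than a server address containing ':'. The sketch below paraphrases that guard; the helper name is made up for illustration and is not part of this commit.

from lmdeploy.archs import autoget_backend

def resolve_gradio_backend(model_path_or_server: str, cli_backend: str) -> str:
    """Paraphrase of the CLI logic above: keep an explicit pytorch choice or a
    server address as-is, otherwise let autoget_backend decide."""
    if cli_backend != 'pytorch' and ':' not in model_path_or_server:
        return autoget_backend(model_path_or_server)
    return cli_backend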
92 changes: 92 additions & 0 deletions lmdeploy/pytorch/supported_models.py
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.

from transformers import AutoConfig

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

_SUPPORTED_ARCHS = dict(
# baichuan-7b
BaiChuanForCausalLM=False,
# baichuan2-7b, baichuan-13b, baichuan2-13b
BaichuanForCausalLM=True,
# chatglm2-6b, chatglm3-6b
ChatGLMModel=True,
# deepseek-moe
DeepseekForCausalLM=True,
# falcon-7b
FalconForCausalLM=True,
# gemma-7b
GemmaForCausalLM=True,
# internlm
InternLMForCausalLM=True,
# internlm2
InternLM2ForCausalLM=True,
# internlm-xcomposer
InternLMXComposerForCausalLM=False,
# internlm2-xcomposer
InternLM2XComposerForCausalLM=False,
# llama, llama2, alpaca, vicuna, codellama, ultracm, yi,
# deepseek-coder, deepseek-llm
LlamaForCausalLM=True,
# Mistral-7B
MistralForCausalLM=True,
# Mixtral-8x7B
MixtralForCausalLM=True,
# Qwen 7B-72B, Qwen-VL-7B
QWenLMHeadModel=False,
# Qwen1.5 7B-72B
Qwen2ForCausalLM=True,
)


def is_supported(model_path: str):
"""Check whether supported by pytorch engine.
Args:
model_path (str): the path of a model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download from
ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
Returns:
support_by_torch (bool): Whether input model is supported by pytorch engine
""" # noqa: E501
import os

support_by_torch = False

triton_model_path = os.path.join(model_path, 'triton_models')
if os.path.exists(triton_model_path):
logger.warning(f'{model_path} seems to be a turbomind workspace, '
'which can only be run with turbomind engine.')
else:
cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

if hasattr(cfg, 'architectures'):
arch = cfg.architectures[0]
elif hasattr(cfg,
'auto_map') and 'AutoModelForCausalLM' in cfg.auto_map:
arch = cfg.auto_map['AutoModelForCausalLM'].split('.')[-1]
else:
raise RuntimeError(
f'Could not find model architecture from config: {cfg}')

if arch in _SUPPORTED_ARCHS:
support_by_torch = _SUPPORTED_ARCHS[arch]
# special cases
if arch == 'BaichuanForCausalLM':
# baichuan-13B not supported by pytorch
if cfg.num_attention_heads == 40 and cfg.vocab_size == 64000:
support_by_torch = False

return support_by_torch
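
A hedged usage sketch of is_supported(); it loads the model's Hugging Face config with trust_remote_code=True, so it needs network access or a local checkout, and the model ids below are illustrative.

from lmdeploy.pytorch.supported_models import is_supported

print(is_supported('internlm/internlm2-chat-7b'))  # True: InternLM2ForCausalLM
print(is_supported('Qwen/Qwen-7B-Chat'))           # False: QWenLMHeadModel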
88 changes: 88 additions & 0 deletions lmdeploy/turbomind/supported_models.py
@@ -0,0 +1,88 @@
# Copyright (c) OpenMMLab. All rights reserved.
from transformers import AutoConfig

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

_SUPPORTED_ARCHS = dict(
# baichuan-7b
BaiChuanForCausalLM=True,
# baichuan2-7b, baichuan-13b, baichuan2-13b
BaichuanForCausalLM=True,
# chatglm2-6b, chatglm3-6b
ChatGLMModel=False,
# deepseek-moe
DeepseekForCausalLM=False,
# falcon-7b
FalconForCausalLM=False,
# gemma-7b
GemmaForCausalLM=False,
# internlm
InternLMForCausalLM=True,
# internlm2
InternLM2ForCausalLM=True,
# internlm-xcomposer
InternLMXComposerForCausalLM=True,
# internlm2-xcomposer
InternLM2XComposerForCausalLM=False,
# llama, llama2, alpaca, vicuna, codellama, ultracm, yi,
# deepseek-coder, deepseek-llm
LlamaForCausalLM=True,
# Mistral-7B
MistralForCausalLM=False,
# Mixtral-8x7B
MixtralForCausalLM=False,
# Qwen 7B-72B, Qwen-VL-7B
QWenLMHeadModel=True,
# Qwen1.5 7B-72B
Qwen2ForCausalLM=False)


def is_supported(model_path: str):
"""Check whether supported by turbomind engine.
Args:
model_path (str): the path of a model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download from
ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
Returns:
support_by_turbomind (bool): Whether input model is supported by turbomind engine
""" # noqa: E501
import os

support_by_turbomind = False
triton_model_path = os.path.join(model_path, 'triton_models')
if os.path.exists(triton_model_path):
support_by_turbomind = True
else:
cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

if hasattr(cfg, 'architectures'):
arch = cfg.architectures[0]
elif hasattr(cfg,
'auto_map') and 'AutoModelForCausalLM' in cfg.auto_map:
arch = cfg.auto_map['AutoModelForCausalLM'].split('.')[-1]
else:
raise RuntimeError(
f'Could not find model architecture from config: {cfg}')

if arch in _SUPPORTED_ARCHS:
support_by_turbomind = _SUPPORTED_ARCHS[arch]
# special cases
if arch == 'BaichuanForCausalLM':
num_attn_head = cfg.num_attention_heads
if num_attn_head == 40:
# baichuan-13B, baichuan2-13B not supported by turbomind
support_by_turbomind = False
return support_by_turbomind
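
A hedged sketch contrasting the two checks on a converted workspace: a directory produced by `lmdeploy convert` contains triton_models, so the turbomind check short-circuits to True while the pytorch check warns and returns False. The path is an assumption for illustration.

from lmdeploy.pytorch.supported_models import is_supported as supported_by_pytorch
from lmdeploy.turbomind.supported_models import is_supported as supported_by_turbomind

workspace = './workspace'  # illustrative path of a `lmdeploy convert` output
print(supported_by_turbomind(workspace))  # True when ./workspace/triton_models exists
print(supported_by_pytorch(workspace))    # False, with a warning in the log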