Skip to content

Commit

Permalink
Fix side effect introduced by PR #1188
Browse files: browse the repository at this point in the history
  • Loading branch information
lvhan028 committed Mar 2, 2024
1 parent f81404a commit de60332
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions lmdeploy/serve/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _config_model_name(config):
raise ArgumentError(None,
f'Please set model_name for {model_path}')
else:
logger.warning(f'Best matched chat template name: {model_name}')
logger.info(f'matched chat template name: {model_name}')
return model_name


Expand Down Expand Up @@ -111,9 +111,10 @@ def __init__(self,
chat_template_config: Optional[ChatTemplateConfig] = None,
tp: int = 1,
**kwargs) -> None:
logger.info(f'AsyncEngine init with backend={backend}, backend_config'
f'={backend_config}, chat_template_config='
f'{chat_template_config}')
logger.info(
f'input backend={backend}, backend_config={backend_config}')
logger.info(f'input chat_template_config={chat_template_config}')

self.model_name = deduce_a_name(model_path, model_name, backend_config,
chat_template_config)
# build chat template config
Expand All @@ -122,33 +123,34 @@ def __init__(self,
elif chat_template_config.model_name is None:
chat_template_config.model_name = self.model_name
self.chat_template = chat_template_config.chat_template

# prevent bc
for k in list(kwargs.keys()):
if hasattr(chat_template_config, k):
logger.warning(f'{k} was deprecated. Please use '
'chat_template_config instead')
v = kwargs.pop(k)
setattr(chat_template_config, k, v)
logger.info(f'updated chat_template_onfig={chat_template_config}')

# build backend engine
if backend == 'turbomind':
logger.info('Running turbomind engine for pipeline.')
self._build_turbomind(model_path=model_path,
backend_config=backend_config,
chat_template_config=chat_template_config,
tp=tp,
**kwargs)
elif backend == 'pytorch':
logger.info('Running pytorch engine for pipeline.')
self._build_pytorch(model_path=model_path,
backend_config=backend_config,
**kwargs)
else:
raise ValueError(f'unsupported backend {backend}')

logger.info(f'updated backend_config={self.backend_config}')

# parameters for member functions
self.session_len = backend_config.session_len
self.backend_config = backend_config
self.session_len = self.backend_config.session_len
self.stop_words = _stop_words(self.chat_template.stop_words,
self.engine.tokenizer)
if self.stop_words is not None:
Expand Down Expand Up @@ -187,6 +189,7 @@ def _build_turbomind(
engine_config=backend_config,
chat_template_config=chat_template_config,
**kwargs)
self.backend_config = backend_config

def _build_pytorch(
self,
Expand All @@ -205,6 +208,7 @@ def _build_pytorch(
backend_config.session_len = self.chat_template.session_len
self.engine = Engine(model_path=model_path,
engine_config=backend_config)
self.backend_config = backend_config

def __call__(self,
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
Expand Down

0 comments on commit de60332

Please sign in to comment.