import modal
import asyncio
from contextlib import asynccontextmanager


def download_model_to_image(model_dir, model_name, model_revision):
    import os

    from huggingface_hub import snapshot_download
    from transformers.utils import move_cache

    os.makedirs(model_dir, exist_ok=True)
    snapshot_download(
        model_name,
        revision=model_revision,
        local_dir=model_dir,
        ignore_patterns=["*.pt", "*.bin"],  # Using safetensors
    )
    move_cache()


MODEL_DIR = "/qwen"
MODEL_NAME = "rawsh/MetaMath-Qwen2.5-0.5b"
MODEL_REVISION = "779b469ef1bb4ef8faac05e46b94c09d38112194"

vllm_image = (
    modal.Image.debian_slim(python_version="3.10")
    .pip_install(
        "vllm==0.6.2",
        "torch==2.4.0",
        "transformers>=4.45",
        "ray==2.36.0",
        "hf-transfer==0.1.8",
        "huggingface_hub==0.25.0",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(
        download_model_to_image,
        timeout=60 * 20,
        secrets=[modal.Secret.from_name("hf-token")],
        kwargs={
            "model_dir": MODEL_DIR,
            "model_name": MODEL_NAME,
            "model_revision": MODEL_REVISION,
        },
    )
    .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"})
)

app = modal.App("vllm-qwen-metamath")

N_GPU = 1
MINUTES = 60
HOURS = 60 * MINUTES

async def get_model_config(engine):
    try:
        return await engine.get_model_config()
    except Exception as e:
        print(f"Error getting model config: {e}")
        raise


@asynccontextmanager
async def lifespan(app):
    # FastAPI lifespan handler: nothing to do on startup; on shutdown, cancel any
    # tasks still running on the event loop so the container exits cleanly.
    try:
        await asyncio.sleep(0)
        yield
    finally:
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

@app.function(
    image=vllm_image,
    gpu=modal.gpu.A10G(count=N_GPU),
    container_idle_timeout=5 * MINUTES,
    timeout=20 * MINUTES,
    allow_concurrent_inputs=1000,
    secrets=[modal.Secret.from_name("vllm-token")],
)
@modal.asgi_app()
def serve():
    import os

    import fastapi
    import vllm.entrypoints.openai.api_server as api_server
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.entrypoints.logger import RequestLogger
    from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
    from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
    from vllm.entrypoints.openai.serving_engine import BaseModelPath
    from vllm.usage.usage_lib import UsageContext
    from transformers import AutoTokenizer

    web_app = fastapi.FastAPI(
        title=f"OpenAI-compatible {MODEL_NAME} server",
        description="Run an OpenAI-compatible LLM server with vLLM on modal.com",
        version="0.0.1",
        docs_url="/docs",
        lifespan=lifespan,
    )
    http_bearer = fastapi.security.HTTPBearer(
        scheme_name="Bearer Token",
        description="See code for authentication details.",
    )
    web_app.add_middleware(
        fastapi.middleware.cors.CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Every route requires the bearer token stored in the "vllm-token" Modal secret.
    TOKEN = os.environ["API_TOKEN"]

    async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
        if api_key.credentials != TOKEN:
            raise fastapi.HTTPException(
                status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication credentials",
            )
        return {"username": "authenticated_user"}

    # Wrap vLLM's OpenAI-compatible router behind the auth dependency.
    router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])
    router.include_router(api_server.router)
    web_app.include_router(router)
    engine_args = AsyncEngineArgs(
        model=MODEL_DIR,
        tensor_parallel_size=N_GPU,
        gpu_memory_utilization=0.90,
        max_model_len=8096,
        enforce_eager=False,
        enable_prefix_caching=True,
    )
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER
    )

    # Fetch the model config once at startup; the call is async, so run it to
    # completion on a temporary event loop before the ASGI app is returned.
    async def setup_engine():
        model_config = await get_model_config(engine)
        return model_config

    model_config = asyncio.run(setup_engine())

    request_logger = RequestLogger(max_log_len=2048)

    base_model_paths = [
        BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME)
    ]
    # Qwen chat template. Earlier variants are kept below, commented out, for reference.
    # TEMPLATE = """{%- for message in messages %}
    # {{- '<|im_start|>' + message.role + '\n' + message.content.strip() + '\n<|im_end|>\n' }}
    # {%- endfor %}
    # {%- if add_generation_prompt %}
    # {{- '<|im_start|>assistant\n' }}
    # {%- endif %}"""

    # This variant also worked well:
    # TEMPLATE = """{%- for message in messages %}
    # {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
    # {%- endfor %}
    # <|im_start|>assistant
    # """

    # Active template: each message is wrapped as <|im_start|>role\n...<|im_end|>\n, except
    # that a trailing assistant message omits its <|im_end|> so the model continues it.
    TEMPLATE = """{%- for message in messages %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' %}{%- if loop.last and message.role == 'assistant' %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content %}{%- endif %}{{- content }}{%- endfor %}"""

    # TEMPLATE = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
    # {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}"""
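
    # For illustration: with the active TEMPLATE, a chat request ending on a user turn,
    # e.g. [{"role": "user", "content": "What is 7 * 8?"}], renders to
    #   <|im_start|>user
    #   What is 7 * 8?<|im_end|>
    # (no generation prompt is appended), while a request ending on an assistant turn is
    # rendered without the closing <|im_end|> so the model continues that partial answer.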
    # Set up the completion endpoint
    api_server.completion = lambda s: OpenAIServingCompletion(
        engine,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=[],
        prompt_adapters=[],
        request_logger=request_logger,
    )

    # Set up the chat endpoint with the Qwen chat template defined above
    api_server.chat = lambda s: OpenAIServingChat(
        engine,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=[],
        prompt_adapters=[],
        request_logger=request_logger,
        response_role="assistant",
        chat_template=TEMPLATE,
    )

    return web_app
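
# Usage sketch (hypothetical values): once deployed with `modal deploy modal_vllm_chat.py`,
# the served app is OpenAI-compatible, so the standard OpenAI Python client can call it.
# The base_url below is a placeholder; Modal prints the real URL on deploy, and the api_key
# must match the API_TOKEN value stored in the "vllm-token" secret.
#
#   from openai import OpenAI
#
#   client = OpenAI(
#       base_url="https://<workspace>--vllm-qwen-metamath-serve.modal.run/v1",  # placeholder URL
#       api_key="<API_TOKEN>",  # value from the vllm-token secret
#   )
#   response = client.chat.completions.create(
#       model="MetaMath-Qwen2.5-0.5b",  # name registered in base_model_paths
#       messages=[{"role": "user", "content": "What is 7 * 8?"}],
#       max_tokens=256,
#   )
#   print(response.choices[0].message.content)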