import argparse
import asyncio
import io
import json
import logging
import os
import sys
import tempfile
from typing import Dict, List, Optional, Tuple, Union

import pandas as pd
import torch
from huggingface_hub import login, snapshot_download
from torchvision.io import decode_image

from ludwig.api import LudwigModel
from ludwig.backend import Backend
from ludwig.callbacks import Callback
from ludwig.constants import AUDIO, COLUMN
from ludwig.contrib import add_contrib_callback_args
from ludwig.globals import LUDWIG_VERSION
from ludwig.utils.print_utils import get_logging_level_registry, print_ludwig
from ludwig.utils.server_utils import NumpyJSONResponse

logger = logging.getLogger(__name__)

try:
    import uvicorn
    from fastapi import FastAPI, status
    from starlette.datastructures import UploadFile
    from starlette.middleware import Middleware
    from starlette.middleware.cors import CORSMiddleware
    from starlette.requests import Request
except ImportError as e:
    logger.error(e)
    logger.error(
        "fastapi and other serving dependencies cannot be loaded "
        "and may not have been installed. "
        "In order to install all serving dependencies run "
        "pip install ludwig[serve]"
    )
    sys.exit(-1)

ALL_FEATURES_PRESENT_ERROR = {"error": "entry must contain all input features"}
COULD_NOT_RUN_INFERENCE_ERROR = {"error": "Unexpected Error: could not run inference on model"}

AUTH_TOKEN = os.getenv("AUTH_TOKEN", "TOKEN_MUST_BE_DEFINED")
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")

# Log in to HuggingFace using the provided access token
if HF_AUTH_TOKEN:
    login(token=HF_AUTH_TOKEN)
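
# Illustrative environment setup (hypothetical token values):
#   export HF_AUTH_TOKEN=hf_xxx   # needed to download models from private HuggingFace repos
#   export AUTH_TOKEN=my-secret   # read above but not otherwise used in this script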


def validate_and_get_project_name(repo_name: str) -> str:
    """Validate a HuggingFace repository name and return the project name.

    Parameters:
        repo_name (str): The repository name in the format 'Owner/ProjectName'.

    Returns:
        str: The project name if the repo_name is valid.

    Raises:
        ValueError: If the repo_name is not in the correct format.
    """
    # Check that the repo name contains exactly one '/'
    if repo_name.count('/') != 1:
        raise ValueError("Invalid repository name format. It must be in 'Owner/ProjectName' format.")
    # Split the repository name into owner and project name
    owner, project_name = repo_name.split('/')
    # Validate that both owner and project name are non-empty
    if not owner or not project_name:
        raise ValueError("Invalid repository name. Both owner and project name must be non-empty.")
    # Return the project name if validation is successful
    return project_name
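
# Illustrative usage (hypothetical repository name):
#   validate_and_get_project_name("some-owner/agnews-model")  # -> "agnews-model"
#   validate_and_get_project_name("agnews-model")             # raises ValueError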


def process_repo_name(repo_name: str, save_dir: Optional[str]) -> Tuple[str, str, str]:
    """Resolve a HuggingFace repo name into (repo_name, local repo directory, local save directory)."""
    if repo_name is not None:
        project_name = validate_and_get_project_name(repo_name)
        repo_dir = os.path.join("repos", project_name)
        if save_dir is not None:
            save_dir = os.path.join("repos", project_name, save_dir)
        else:
            save_dir = os.path.join("repos", project_name)
        return repo_name, repo_dir, save_dir
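
# Illustrative result (hypothetical repo): process_repo_name("some-owner/agnews-model", "model")
# returns ("some-owner/agnews-model", "repos/agnews-model", "repos/agnews-model/model").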


def download_model_from_huggingface(
    repo_id: str,
    repo_dir: str,
    force_download: bool = True,
) -> None:
    """Download the model from Hugging Face if not already present in the local directory.

    Args:
    - repo_id: Hugging Face repository ID of the model.
    - repo_dir: Local directory to store the downloaded model.
    - force_download: If True, forces the download even if the model is already present.
    """
    # Check whether the model is already present
    if not os.path.isdir(repo_dir) or force_download:
        # Create the repository directory if it doesn't exist
        os.makedirs(repo_dir, exist_ok=True)
        # Download the model snapshot from Hugging Face
        try:
            snapshot_download(
                repo_id=repo_id,
                force_download=force_download,
                local_dir=repo_dir,
                repo_type="model",
            )
        except Exception as e:
            logger.warning(f"Failed to download the model from Hugging Face: {e}")


def load_model(
    save_dir: str,
    repo_dir: str,
    repo_id: str,
    logging_level: int = logging.ERROR,
    backend: Optional[Union[str, Backend]] = None,
    gpus: Optional[Union[str, int, List[int]]] = None,
    gpu_memory_limit: Optional[float] = None,
    allow_parallel_threads: bool = True,
    callbacks: Optional[List[Callback]] = None,
    from_checkpoint: bool = False,
) -> Optional[LudwigModel]:
    """Load a pretrained model from the specified directory, or download it from Hugging Face if necessary.

    This function first checks whether the model is available locally and, if not,
    downloads it from the Hugging Face Hub. It then loads the model using Ludwig's
    `LudwigModel.load()` method, which restores the model's weights, config,
    and other metadata.

    Args:
    - save_dir: Directory where the model's checkpoint and weights are saved.
    - repo_dir: Directory where the model should be downloaded from Hugging Face.
    - repo_id: The Hugging Face repository ID of the model.
    - logging_level: Log level for logs.
    - backend: Backend to use for execution.
    - gpus: GPUs to use for model execution.
    - gpu_memory_limit: Maximum GPU memory fraction allowed.
    - allow_parallel_threads: Allow multithreading for Torch.
    - callbacks: List of callbacks to use during the model pipeline.
    - from_checkpoint: If True, loads from the checkpoint rather than the final weights.

    Returns:
    - Optional[LudwigModel]: A LudwigModel instance with restored weights and metadata, or `None` if an error occurred.
    """
    try:
        # Step 1: Check whether the model checkpoint exists locally; if not, download it
        if not os.path.isfile(os.path.join(save_dir, "checkpoint")):
            download_model_from_huggingface(repo_id, repo_dir)
        # Step 2: Load the model using LudwigModel.load()
        ludwig_model = LudwigModel.load(
            model_dir=save_dir,
            logging_level=logging_level,
            backend=backend,
            gpus=gpus,
            gpu_memory_limit=gpu_memory_limit,
            allow_parallel_threads=allow_parallel_threads,
            callbacks=callbacks,
            from_checkpoint=from_checkpoint,
        )
        return ludwig_model
    except Exception as e:
        logger.warning(f"Failed to load the model: {e}")
        return None
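
# Illustrative usage (hypothetical repo ID), mirroring the huggingface mode in run_server() below:
#   model = load_model(
#       save_dir="repos/agnews-model/model",
#       repo_dir="repos/agnews-model",
#       repo_id="some-owner/agnews-model",
#       backend="local",
#   )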


def server(models, allowed_origins=None):
    middleware = [Middleware(CORSMiddleware, allow_origins=allowed_origins)] if allowed_origins else None
    app = FastAPI(middleware=middleware)

    @app.get("/")
    def check_health():
        return NumpyJSONResponse({"message": "Ludwig server is up", "models": list(models.keys())})

    @app.post("/predict")
    async def predict(request: Request):
        try:
            # Parse the multipart form data
            form = await request.form()
            model_names = form.get("model")  # Single model name or comma-separated list of names
            model_names = model_names.split(",") if model_names else None
            files = []
        except Exception:
            logger.exception("Failed to parse predict form")
            return NumpyJSONResponse(COULD_NOT_RUN_INFERENCE_ERROR, status_code=500)

        async def predict_by_model(model_name: str, model: LudwigModel) -> dict:
            try:
                # Convert the shared form entry into the format this model's input features expect
                entry, model_files = convert_input(form, model.model.input_features)
                files.extend(model_files)  # Track temp files so the outer handler can clean them up
                input_features = {f[COLUMN] for f in model.config["input_features"]}
                if (entry.keys() & input_features) != input_features:
                    missing_features = set(input_features) - set(entry.keys())
                    return {
                        "model": model_name,
                        "response": {
                            "error": f"Missing features: {missing_features}.",
                            "status": "failed",
                        },
                    }
                resp, _ = model.predict(dataset=[entry], data_format=dict)
                return {
                    "model": model_name,
                    "response": {
                        "predictions": resp.to_dict("records")[0],
                        "status": "success",
                    },
                }
            except Exception as exc:
                logger.exception(f"Failed to predict for model '{model_name}': {exc}")
                return {
                    "model": model_name,
                    "response": {
                        "error": str(exc),
                        "status": "failed",
                    },
                }

        try:
            # Determine which models to run
            if model_names:
                invalid_models = [name for name in model_names if name not in models]
                if invalid_models:
                    return NumpyJSONResponse(
                        {"error": f"Invalid model names: {invalid_models}. Available models: {list(models.keys())}."},
                        status_code=400,
                    )
                target_models = {name: models[name] for name in model_names}
            else:
                # Predict with all models if no specific model(s) are requested
                target_models = models
            # Run one prediction task per model concurrently
            tasks = [predict_by_model(name, model) for name, model in target_models.items()]
            results = await asyncio.gather(*tasks)
            responses = {result["model"]: result["response"] for result in results}
            return NumpyJSONResponse(responses)
        except Exception:
            logger.exception("Failed to execute predictions")
            return NumpyJSONResponse(COULD_NOT_RUN_INFERENCE_ERROR, status_code=500)
        finally:
            # Clean up any temporary files written while converting the inputs
            for f in files:
                os.remove(f.name)
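
    # Illustrative request against /predict (hypothetical model name "agnews" and feature name "text"):
    #   curl http://localhost:8000/predict -X POST -F 'model=agnews' -F 'text=Some headline'
    # Omitting the "model" field runs the entry through every loaded model.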
@app.post("/batch_predict")
async def batch_predict(request: Request):
try:
# Parse form data
form = await request.form()
model_names = form.get("model") # Single model or list of models
model_names = model_names.split(",") if model_names else None
files=[]
except Exception:
logger.exception("Failed to parse batch_predict form")
return NumpyJSONResponse(COULD_NOT_RUN_INFERENCE_ERROR, status_code=500)

        async def batch_predict_by_model(model_name: str, model: LudwigModel) -> dict:
            try:
                # Convert the uploaded dataset into a DataFrame compatible with this model's input features
                data, model_files = convert_batch_input(form, model.model.input_features)
                files.extend(model_files)  # Track temp files so the outer handler can clean them up
                data_df = pd.DataFrame.from_records(data["data"], index=data.get("index"), columns=data["columns"])
                input_features = {f[COLUMN] for f in model.config["input_features"]}
                if (set(data_df.columns) & input_features) != input_features:
                    missing_features = set(input_features) - set(data_df.columns)
                    return {
                        "model": model_name,
                        "response": {
                            "error": f"Missing features: {missing_features}.",
                            "status": "failed",
                        },
                    }
                resp, _ = model.predict(dataset=data_df)
                return {
                    "model": model_name,
                    "response": {
                        "predictions": resp.to_dict("split"),
                        "status": "success",
                    },
                }
            except Exception as exc:
                logger.exception(f"Failed to batch predict for model '{model_name}': {exc}")
                return {
                    "model": model_name,
                    "response": {
                        "error": str(exc),
                        "status": "failed",
                    },
                }

        try:
            # Determine which models to run
            if model_names:
                invalid_models = [name for name in model_names if name not in models]
                if invalid_models:
                    return NumpyJSONResponse(
                        {"error": f"Invalid model names: {invalid_models}. Available models: {list(models.keys())}."},
                        status_code=400,
                    )
                target_models = {name: models[name] for name in model_names}
            else:
                # Predict with all models if no specific model(s) are requested
                target_models = models
            # Run one batch prediction task per model concurrently
            tasks = [batch_predict_by_model(name, model) for name, model in target_models.items()]
            results = await asyncio.gather(*tasks)
            responses = {result["model"]: result["response"] for result in results}
            return NumpyJSONResponse(responses)
        except Exception:
            logger.exception("Failed to execute batch predictions")
            return NumpyJSONResponse(COULD_NOT_RUN_INFERENCE_ERROR, status_code=500)
        finally:
            # Clean up any temporary files written while converting the inputs
            for f in files:
                os.remove(f.name)
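
    # Illustrative request against /batch_predict (hypothetical model and feature names); the
    # "dataset" field uses the pandas "split" orientation consumed by convert_batch_input():
    #   curl http://localhost:8000/batch_predict -X POST -F 'model=agnews' \
    #        -F 'dataset={"columns": ["text"], "data": [["headline one"], ["headline two"]]}'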

    return app


def _write_file(v, files):
    # Convert UploadFile to a NamedTemporaryFile to ensure it's on the disk
    suffix = os.path.splitext(v.filename)[1]
    named_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    files.append(named_file)
    named_file.write(v.file.read())
    named_file.close()
    return named_file.name


def _read_image_buffer(v):
    # Read bytes sent via the REST API and convert them to an image tensor
    # in [channels, height, width] format
    byte_string = io.BytesIO(v.file.read()).read()
    image = decode_image(torch.frombuffer(byte_string, dtype=torch.uint8))
    return image  # channels, height, width


def convert_input(form, input_features):
    """Returns a new input entry and a list of temporary files to be cleaned up."""
    new_input = {}
    files = []
    for k, v in form.multi_items():
        if isinstance(v, UploadFile):
            # Check whether the upload is an audio or an image file
            if input_features.get(k).type() == AUDIO:
                new_input[k] = _write_file(v, files)
            else:
                new_input[k] = _read_image_buffer(v)
        else:
            new_input[k] = v
    return new_input, files


def convert_batch_input(form, input_features):
    """Returns a new batch input and a list of temporary files to be cleaned up."""
    file_index = {}
    files = []
    # Index uploaded files by filename so dataset cells can reference them
    for k, v in form.multi_items():
        if isinstance(v, UploadFile):
            file_index[v.filename] = v
    data = json.loads(form["dataset"])
    # Replace filename references in the dataset with temp file paths (audio) or decoded image tensors
    for row in data["data"]:
        for i, value in enumerate(row):
            if value in file_index:
                feature_name = data["columns"][i]
                if input_features.get(feature_name).type() == AUDIO:
                    row[i] = _write_file(file_index[value], files)
                else:
                    row[i] = _read_image_buffer(file_index[value])
    return data, files
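
# Illustrative "dataset" payload expected by /batch_predict (hypothetical column names), using the
# pandas "split" orientation; cells that name an uploaded file are resolved via file_index above:
#   {"columns": ["image_path", "label"], "data": [["photo1.png", "cat"], ["photo2.png", "dog"]]}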


async def are_models_loaded(models: Dict[str, LudwigModel]) -> bool:
    # Verify that every model was loaded successfully
    return all(model is not None for model in models.values())


def run_server(
    model_paths: dict,  # Dictionary of model names to paths or HuggingFace repo IDs
    mode: str,
    host: str,
    port: int,
    allowed_origins: list,
) -> None:
    """Loads pre-trained models and serves them on an http server."""
    # If model_paths is a JSON string, convert it to a dictionary
    if isinstance(model_paths, str):
        model_paths = json.loads(model_paths)

    models = {}
    for model_name, repo_id in model_paths.items():
        if mode == "huggingface":
            repo_id, repo_dir, save_dir = process_repo_name(repo_id, "model")
            models[model_name] = load_model(save_dir, repo_dir, repo_id=repo_id, backend="local")
        elif mode == "local":
            models[model_name] = LudwigModel.load(repo_id, backend="local")

    # Check that all models are loaded before starting the server
    if not asyncio.run(are_models_loaded(models)):
        headers = {"Retry-After": "120"}  # Suggest retrying after 2 minutes
        response = NumpyJSONResponse(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            content={"message": "Models are still loading, please retry later."},
            headers=headers,
        )
        return response

    app = server(models, allowed_origins)
    uvicorn.run(app, host=host, port=port)
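
# Illustrative --model_paths values (hypothetical repo IDs and paths):
#   huggingface mode: '{"agnews": "some-owner/agnews-model", "sst": "some-owner/sst-model"}'
#   local mode:       '{"agnews": "results/experiment_run/model"}'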


def cli(sys_argv):
    parser = argparse.ArgumentParser(
        description="This script serves multiple pretrained models",
        prog="ludwig multi_serve",
        usage="%(prog)s [options]",
    )

    # ----------------
    # Model parameters
    # ----------------
    parser.add_argument(
        "-m",
        "--model_paths",
        help="JSON dictionary mapping model names to model paths or HuggingFace repo IDs",
        required=True,
    )
    parser.add_argument(
        "-mode",
        "--mode",
        choices=["huggingface", "local"],
        help="Model loading mode: fetch models from HuggingFace or load them locally",
        required=True,
    )
    parser.add_argument(
        "-l",
        "--logging_level",
        default="info",
        help="the level of logging to use",
        choices=["critical", "error", "warning", "info", "debug", "notset"],
    )

    # ----------------
    # Server parameters
    # ----------------
    parser.add_argument(
        "-p",
        "--port",
        help="port for server (default: 8000)",
        default=8000,
        type=int,
    )
    parser.add_argument("-H", "--host", help="host for server (default: 0.0.0.0)", default="0.0.0.0")
    parser.add_argument(
        "-ao",
        "--allowed_origins",
        nargs="*",
        help="A list of origins that should be permitted to make cross-origin requests. "
        'Use "*" to allow any origin. See https://www.starlette.io/middleware/#corsmiddleware.',
    )
    add_contrib_callback_args(parser)
    args = parser.parse_args(sys_argv)

    args.callbacks = args.callbacks or []
    for callback in args.callbacks:
        callback.on_cmdline("serve", *sys_argv)

    args.logging_level = get_logging_level_registry()[args.logging_level]
    logging.getLogger("ludwig").setLevel(args.logging_level)
    global logger
    logger = logging.getLogger("ludwig.serve")

    print_ludwig("Serve", LUDWIG_VERSION)
    run_server(args.model_paths, args.mode, args.host, args.port, args.allowed_origins)


if __name__ == "__main__":
    cli(sys.argv[1:])
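
# Illustrative invocation (hypothetical repo ID; adjust feature names to your model's config):
#   python multi_serve.py -m '{"agnews": "some-owner/agnews-model"}' --mode huggingface -p 8000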