diff --git a/python/llm/src/ipex_llm/vllm/cpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/cpu/engine/engine.py index 1210d5dc313..942eb3ae6fb 100644 --- a/python/llm/src/ipex_llm/vllm/cpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/cpu/engine/engine.py @@ -230,21 +230,23 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, return super().from_engine_args(engine_args, usage_context, ipc_path) +def signal_handler(*_) -> None: + raise KeyboardInterrupt("MQLLMEngine terminated") # noqa + + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str, load_in_low_bit: str, engine_alive): - def signal_handler(*_) -> None: - # Interrupt server on sigterm - raise KeyboardInterrupt("MQLLMEngine terminated") # noqa - try: - signal.signal(signal.SIGTERM, signal_handler) - engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args, usage_context=usage_context, ipc_path=ipc_path, load_in_low_bit=load_in_low_bit) + + signal.signal(signal.SIGTERM, signal_handler) + engine.start() + except BaseException as e: logger.exception(e) engine_alive.value = False diff --git a/python/llm/src/ipex_llm/vllm/cpu/entrypoints/api_server.py b/python/llm/src/ipex_llm/vllm/cpu/entrypoints/api_server.py index e1f50daab0e..d2343ecfb35 100644 --- a/python/llm/src/ipex_llm/vllm/cpu/entrypoints/api_server.py +++ b/python/llm/src/ipex_llm/vllm/cpu/entrypoints/api_server.py @@ -1,787 +1,174 @@ +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine +and simple performance benchmarks. It is not intended for production use. +For production use, we recommend using our OpenAI compatible server. +We are also not going to accept PRs modifying this file, please +change `vllm/entrypoints/openai/api_server.py` instead. 
+""" import asyncio -import atexit -import importlib -import inspect -import multiprocessing -import os -import re -import signal -import socket -import tempfile -import uuid +import json +import ssl from argparse import Namespace -from contextlib import asynccontextmanager -from functools import partial -from http import HTTPStatus -from typing import AsyncIterator, Optional, Set, Tuple - -import uvloop -from fastapi import APIRouter, FastAPI, Request -from fastapi.exceptions import RequestValidationError -from fastapi.middleware.cors import CORSMiddleware +from typing import Any, AsyncGenerator, Optional + +from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse -from starlette.datastructures import State -from starlette.routing import Mount -from typing_extensions import assert_never -import vllm.envs as envs -from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs # from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore from ipex_llm.vllm.cpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine -from vllm.engine.multiprocessing.client import MQLLMEngineClient -# from vllm.engine.multiprocessing.engine import run_mp_engine -from ipex_llm.vllm.cpu.engine import run_mp_engine -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.launcher import serve_http -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) -from vllm.entrypoints.openai.serving_engine import OpenAIServing -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionResponse, - CompletionRequest, - CompletionResponse, - DetokenizeRequest, - DetokenizeResponse, - EmbeddingRequest, - EmbeddingResponse, - EmbeddingResponseData, - ErrorResponse, - LoadLoraAdapterRequest, - PoolingRequest, PoolingResponse, - ScoreRequest, ScoreResponse, - TokenizeRequest, - TokenizeResponse, - UnloadLoraAdapterRequest) -# yapf: enable -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -# from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) - -from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling -from vllm.entrypoints.openai.serving_score import OpenAIServingScores -from vllm.entrypoints.openai.serving_tokenization import ( - OpenAIServingTokenization) -from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext -from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, - is_valid_ipv6_address, set_ulimit) +from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit from vllm.version import __version__ as VLLM_VERSION -TIMEOUT_KEEP_ALIVE = 5 # seconds - -prometheus_multiproc_dir: tempfile.TemporaryDirectory - -# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) -logger = init_logger('vllm.entrypoints.openai.api_server') - -_running_tasks: Set[asyncio.Task] = 
set() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - try: - if app.state.log_stats: - engine_client: EngineClient = app.state.engine_client - - async def _force_log(): - while True: - await asyncio.sleep(10.) - await engine_client.do_log_stats() - - task = asyncio.create_task(_force_log()) - _running_tasks.add(task) - task.add_done_callback(_running_tasks.remove) - else: - task = None - try: - yield - finally: - if task is not None: - task.cancel() - finally: - # Ensure app state including engine ref is gc'd - del app.state - - -@asynccontextmanager -async def build_async_engine_client( - args: Namespace) -> AsyncIterator[EngineClient]: - - # Context manager to handle engine_client lifecycle - # Ensures everything is shutdown and cleaned up on error/exit - engine_args = AsyncEngineArgs.from_cli_args(args) - - async with build_async_engine_client_from_engine_args( - engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine: - yield engine - - -@asynccontextmanager -async def build_async_engine_client_from_engine_args( - engine_args: AsyncEngineArgs, - disable_frontend_multiprocessing: bool = False, - load_in_low_bit: str = "sym_int4", -) -> AsyncIterator[EngineClient]: - """ - Create EngineClient, either: - - in-process using the AsyncLLMEngine Directly - - multiprocess using AsyncLLMEngine RPC - - Returns the Client or None if the creation failed. - """ - - # Fall back - # TODO: fill out feature matrix. - if (MQLLMEngineClient.is_unsupported_config(engine_args) - or envs.VLLM_USE_V1 or disable_frontend_multiprocessing): - engine_config = engine_args.create_engine_config( - UsageContext.OPENAI_API_SERVER) - uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), - "uses_ray", False) - - build_engine = partial(AsyncLLMEngine.from_engine_args, - load_in_low_bit=load_in_low_bit, - engine_args=engine_args, - engine_config=engine_config, - usage_context=UsageContext.OPENAI_API_SERVER) - if uses_ray: - # Must run in main thread with ray for its signal handlers to work - engine_client = build_engine() - else: - engine_client = await asyncio.get_running_loop().run_in_executor( - None, build_engine) - - yield engine_client - if hasattr(engine_client, "shutdown"): - engine_client.shutdown() - return - - # Otherwise, use the multiprocessing AsyncLLMEngine. - else: - if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: - # Make TemporaryDirectory for prometheus multiprocessing - # Note: global TemporaryDirectory will be automatically - # cleaned up upon exit. - global prometheus_multiproc_dir - prometheus_multiproc_dir = tempfile.TemporaryDirectory() - os.environ[ - "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name - else: - logger.warning( - "Found PROMETHEUS_MULTIPROC_DIR was set by user. " - "This directory must be wiped between vLLM runs or " - "you will find inaccurate metrics. Unset the variable " - "and vLLM will properly handle cleanup.") - - # Select random path for IPC. - ipc_path = get_open_zmq_ipc_path() - logger.debug("Multiprocessing frontend to use %s for IPC Path.", - ipc_path) - - # Start RPCServer in separate process (holds the LLMEngine). - # the current process might have CUDA context, - # so we need to spawn a new process - context = multiprocessing.get_context("spawn") - - # The Process can raise an exception during startup, which may - # not actually result in an exitcode being reported. As a result - # we use a shared variable to communicate the information. 
- engine_alive = multiprocessing.Value('b', True, lock=False) - engine_process = context.Process(target=run_mp_engine, - args=(engine_args, - UsageContext.OPENAI_API_SERVER, - ipc_path, load_in_low_bit, engine_alive)) - engine_process.start() - engine_pid = engine_process.pid - assert engine_pid is not None, "Engine process failed to start." - logger.info("Started engine process with PID %d", engine_pid) - - def _cleanup_ipc_path(): - socket_path = ipc_path.replace("ipc://", "") - if os.path.exists(socket_path): - os.remove(socket_path) - - # Ensure we clean up the local IPC socket file on exit. - atexit.register(_cleanup_ipc_path) - - # Build RPCClient, which conforms to EngineClient Protocol. - engine_config = engine_args.create_engine_config() - build_client = partial(MQLLMEngineClient, ipc_path, engine_config, - engine_pid) - mq_engine_client = await asyncio.get_running_loop().run_in_executor( - None, build_client) - try: - while True: - try: - await mq_engine_client.setup() - break - except TimeoutError: - if (not engine_process.is_alive() - or not engine_alive.value): - raise RuntimeError( - "Engine process failed to start. See stack " - "trace for the root cause.") from None - - yield mq_engine_client # type: ignore[misc] - finally: - # Ensure rpc server process was terminated - engine_process.terminate() - - # Close all open connections to the backend - mq_engine_client.close() - - # Wait for engine process to join - engine_process.join(4) - if engine_process.exitcode is None: - # Kill if taking longer than 5 seconds to stop - engine_process.kill() - - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import multiprocess - multiprocess.mark_process_dead(engine_process.pid) - - -router = APIRouter() - - -def mount_metrics(app: FastAPI): - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import (CollectorRegistry, make_asgi_app, - multiprocess) - - prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) - if prometheus_multiproc_dir_path is not None: - logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", - prometheus_multiproc_dir_path) - registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) - - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) - else: - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app()) - - # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile("^/metrics(?P.*)$") - app.routes.append(metrics_route) - - -def base(request: Request) -> OpenAIServing: - # Reuse the existing instance - return tokenization(request) - - -def chat(request: Request) -> Optional[OpenAIServingChat]: - return request.app.state.openai_serving_chat - - -def completion(request: Request) -> Optional[OpenAIServingCompletion]: - return request.app.state.openai_serving_completion - +logger = init_logger("vllm.entrypoints.api_server") -def pooling(request: Request) -> Optional[OpenAIServingPooling]: - return request.app.state.openai_serving_pooling +TIMEOUT_KEEP_ALIVE = 5 # seconds. 
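For reference, a minimal client sketch for the demo `/generate` endpoint described in the module docstring above and defined in the hunks that follow. It assumes the server was started with the defaults from this file's argument parser (localhost:8000); the prompt text and `max_tokens` value are arbitrary, and any field other than `prompt`/`stream` is forwarded to `SamplingParams`:

# Illustrative client for the demo server (not part of this patch).
import json
from urllib.request import Request, urlopen

payload = {"prompt": "What is IPEX-LLM?", "stream": False, "max_tokens": 64}
req = Request("http://localhost:8000/generate",
              data=json.dumps(payload).encode("utf-8"),
              headers={"Content-Type": "application/json"})
with urlopen(req) as resp:
    print(json.loads(resp.read())["text"])

The non-streaming path returns a JSON body of the form {"text": ["<prompt + completion>", ...]}; with "stream": true the server instead emits one JSON object per line.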
+app = FastAPI() +engine = None -def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: - return request.app.state.openai_serving_embedding - - -def score(request: Request) -> Optional[OpenAIServingScores]: - return request.app.state.openai_serving_scores - - -def tokenization(request: Request) -> OpenAIServingTokenization: - return request.app.state.openai_serving_tokenization - - -def engine_client(request: Request) -> EngineClient: - return request.app.state.engine_client - - -@router.get("/health") -async def health(raw_request: Request) -> Response: +@app.get("/health") +async def health() -> Response: """Health check.""" - await engine_client(raw_request).check_health() return Response(status_code=200) -@router.post("/tokenize") -@with_cancellation -async def tokenize(request: TokenizeRequest, raw_request: Request): - handler = tokenization(raw_request) - - generator = await handler.create_tokenize(request, raw_request) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.code) - elif isinstance(generator, TokenizeResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) - - -@router.post("/detokenize") -@with_cancellation -async def detokenize(request: DetokenizeRequest, raw_request: Request): - handler = tokenization(raw_request) - - generator = await handler.create_detokenize(request, raw_request) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.code) - elif isinstance(generator, DetokenizeResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) - - -@router.get("/v1/models") -async def show_available_models(raw_request: Request): - handler = base(raw_request) - - models = await handler.show_available_models() - return JSONResponse(content=models.model_dump()) - - -@router.get("/version") -async def show_version(): - ver = {"version": VLLM_VERSION} - return JSONResponse(content=ver) - +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. 
-@router.post("/v1/chat/completions") -@with_cancellation -async def create_chat_completion(request: ChatCompletionRequest, - raw_request: Request): - handler = chat(raw_request) - if handler is None: - return base(raw_request).create_error_response( - message="The model does not support Chat Completions API") - - generator = await handler.create_chat_completion(request, raw_request) - - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.code) - - elif isinstance(generator, ChatCompletionResponse): - return JSONResponse(content=generator.model_dump()) - - return StreamingResponse(content=generator, media_type="text/event-stream") - - -@router.post("/v1/completions") -@with_cancellation -async def create_completion(request: CompletionRequest, raw_request: Request): - handler = completion(raw_request) - if handler is None: - return base(raw_request).create_error_response( - message="The model does not support Completions API") - - generator = await handler.create_completion(request, raw_request) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.code) - elif isinstance(generator, CompletionResponse): - return JSONResponse(content=generator.model_dump()) - - return StreamingResponse(content=generator, media_type="text/event-stream") - - -@router.post("/v1/embeddings") -@with_cancellation -async def create_embedding(request: EmbeddingRequest, raw_request: Request): - handler = embedding(raw_request) - if handler is None: - fallback_handler = pooling(raw_request) - if fallback_handler is None: - return base(raw_request).create_error_response( - message="The model does not support Embeddings API") - - logger.warning( - "Embeddings API will become exclusive to embedding models " - "in a future release. To return the hidden states directly, " - "use the Pooling API (`/pooling`) instead.") - - res = await fallback_handler.create_pooling(request, raw_request) - if isinstance(res, PoolingResponse): - generator = EmbeddingResponse( - id=res.id, - object=res.object, - created=res.created, - model=res.model, - data=[ - EmbeddingResponseData( - index=d.index, - embedding=d.data, # type: ignore - ) for d in res.data - ], - usage=res.usage, - ) - else: - generator = res - else: - generator = await handler.create_embedding(request, raw_request) - - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.code) - elif isinstance(generator, EmbeddingResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) - - -@router.post("/pooling") -@with_cancellation -async def create_pooling(request: PoolingRequest, raw_request: Request): - handler = pooling(raw_request) - if handler is None: - return base(raw_request).create_error_response( - message="The model does not support Pooling API") - - generator = await handler.create_pooling(request, raw_request) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.code) - elif isinstance(generator, PoolingResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). 
+ """ + request_dict = await request.json() + return await _generate(request_dict, raw_request=request) -@router.post("/score") @with_cancellation -async def create_score(request: ScoreRequest, raw_request: Request): - handler = score(raw_request) - if handler is None: - return base(raw_request).create_error_response( - message="The model does not support Score API") - - generator = await handler.create_score(request, raw_request) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.code) - elif isinstance(generator, ScoreResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) +async def _generate(request_dict: dict, raw_request: Request) -> Response: + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + assert engine is not None + results_generator = engine.generate(prompt, sampling_params, request_id) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + assert prompt is not None + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\n").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + try: + async for request_output in results_generator: + final_output = request_output + except asyncio.CancelledError: + return Response(status_code=499) - -@router.post("/v1/score") -@with_cancellation -async def create_score_v1(request: ScoreRequest, raw_request: Request): - logger.warning( - "To indicate that Score API is not part of standard OpenAI API, we " - "have moved it to `/score`. Please update your client accordingly.") - - return await create_score(request, raw_request) - - -if envs.VLLM_TORCH_PROFILER_DIR: - logger.warning( - "Torch Profiler is enabled in the API server. This should ONLY be " - "used for local development!") - - @router.post("/start_profile") - async def start_profile(raw_request: Request): - logger.info("Starting profiler...") - await engine_client(raw_request).start_profile() - logger.info("Profiler started.") - return Response(status_code=200) - - @router.post("/stop_profile") - async def stop_profile(raw_request: Request): - logger.info("Stopping profiler...") - await engine_client(raw_request).stop_profile() - logger.info("Profiler stopped.") - return Response(status_code=200) - - -if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - logger.warning( - "Lora dynamic loading & unloading is enabled in the API server. 
" - "This should ONLY be used for local development!") - - @router.post("/v1/load_lora_adapter") - async def load_lora_adapter(request: LoadLoraAdapterRequest, - raw_request: Request): - for route in [chat, completion, embedding]: - handler = route(raw_request) - if handler is not None: - response = await handler.load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) - - return Response(status_code=200, content=response) - - @router.post("/v1/unload_lora_adapter") - async def unload_lora_adapter(request: UnloadLoraAdapterRequest, - raw_request: Request): - for route in [chat, completion, embedding]: - handler = route(raw_request) - if handler is not None: - response = await handler.unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) - - return Response(status_code=200, content=response) + assert final_output is not None + prompt = final_output.prompt + assert prompt is not None + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) def build_app(args: Namespace) -> FastAPI: - if args.disable_fastapi_docs: - app = FastAPI(openapi_url=None, - docs_url=None, - redoc_url=None, - lifespan=lifespan) - else: - app = FastAPI(lifespan=lifespan) - app.include_router(router) - app.root_path = args.root_path - - mount_metrics(app) - - app.add_middleware( - CORSMiddleware, - allow_origins=args.allowed_origins, - allow_credentials=args.allow_credentials, - allow_methods=args.allowed_methods, - allow_headers=args.allowed_headers, - ) - - @app.exception_handler(RequestValidationError) - async def validation_exception_handler(_, exc): - err = ErrorResponse(message=str(exc), - type="BadRequestError", - code=HTTPStatus.BAD_REQUEST) - return JSONResponse(err.model_dump(), - status_code=HTTPStatus.BAD_REQUEST) - - if token := envs.VLLM_API_KEY or args.api_key: - - @app.middleware("http") - async def authentication(request: Request, call_next): - if request.method == "OPTIONS": - return await call_next(request) - url_path = request.url.path - if app.root_path and url_path.startswith(app.root_path): - url_path = url_path[len(app.root_path):] - if not url_path.startswith("/v1"): - return await call_next(request) - if request.headers.get("Authorization") != "Bearer " + token: - return JSONResponse(content={"error": "Unauthorized"}, - status_code=401) - return await call_next(request) - - if args.enable_request_id_headers: - logger.warning( - "CAUTION: Enabling X-Request-Id headers in the API Server. " - "This can harm performance at high QPS.") - - @app.middleware("http") - async def add_request_id(request: Request, call_next): - request_id = request.headers.get( - "X-Request-Id") or uuid.uuid4().hex - response = await call_next(request) - response.headers["X-Request-Id"] = request_id - return response - - for middleware in args.middleware: - module_path, object_name = middleware.rsplit(".", 1) - imported = getattr(importlib.import_module(module_path), object_name) - if inspect.isclass(imported): - app.add_middleware(imported) - elif inspect.iscoroutinefunction(imported): - app.middleware("http")(imported) - else: - raise ValueError(f"Invalid middleware {middleware}. 
" - f"Must be a function or a class.") + global app + app.root_path = args.root_path return app -def init_app_state( - engine_client: EngineClient, - model_config: ModelConfig, - state: State, +async def init_app( args: Namespace, -) -> None: - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] - - if args.disable_log_requests: - request_logger = None - else: - request_logger = RequestLogger(max_log_len=args.max_log_len) - - base_model_paths = [ - BaseModelPath(name=name, model_path=args.model) - for name in served_model_names - ] - - state.engine_client = engine_client - state.log_stats = not args.disable_log_stats - - resolved_chat_template = load_chat_template(args.chat_template) - logger.info("Using supplied chat template:\n%s", resolved_chat_template) - - state.openai_serving_chat = OpenAIServingChat( - engine_client, - model_config, - base_model_paths, - args.response_role, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - tool_parser=args.tool_call_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - ) if model_config.runner_type == "generate" else None - state.openai_serving_completion = OpenAIServingCompletion( - engine_client, - model_config, - base_model_paths, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, - request_logger=request_logger, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) if model_config.runner_type == "generate" else None - state.openai_serving_pooling = OpenAIServingPooling( - engine_client, - model_config, - base_model_paths, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - ) if model_config.runner_type == "pooling" else None - state.openai_serving_embedding = OpenAIServingEmbedding( - engine_client, - model_config, - base_model_paths, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - ) if model_config.task == "embed" else None - state.openai_serving_scores = OpenAIServingScores( - engine_client, - model_config, - base_model_paths, - request_logger=request_logger - ) if model_config.task == "score" else None - state.openai_serving_tokenization = OpenAIServingTokenization( - engine_client, - model_config, - base_model_paths, - lora_modules=args.lora_modules, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - ) + llm_engine: Optional[AsyncLLMEngine] = None, +) -> FastAPI: + app = build_app(args) + global engine -def create_server_socket(addr: Tuple[str, int]) -> socket.socket: - family = socket.AF_INET - if is_valid_ipv6_address(addr[0]): - family = socket.AF_INET6 - - sock = socket.socket(family=family, type=socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(addr) + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER)) - return sock + return app -async def run_server(args, **uvicorn_kwargs) -> None: 
+async def run_server(args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, + **uvicorn_kwargs: Any) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) - - valide_tool_parses = ToolParserManager.tool_parsers.keys() - if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valide_tool_parses: - raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " - f"(chose from {{ {','.join(valide_tool_parses)} }})") - - # workaround to make sure that we bind the port before the engine is set up. - # This avoids race conditions with ray. - # see https://github.com/vllm-project/vllm/issues/8204 - sock_addr = (args.host or "", args.port) - sock = create_server_socket(sock_addr) - - # workaround to avoid footguns where uvicorn drops requests with too - # many concurrent requests active set_ulimit() - def signal_handler(*_) -> None: - # Interrupt server on sigterm while initializing - raise KeyboardInterrupt("terminated") - - signal.signal(signal.SIGTERM, signal_handler) - - async with build_async_engine_client(args) as engine_client: - app = build_app(args) - - model_config = await engine_client.get_model_config() - init_app_state(engine_client, model_config, app.state, args) - - shutdown_task = await serve_http( - app, - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - **uvicorn_kwargs, - ) + app = await init_app(args, llm_engine) + assert engine is not None + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) - # NB: Await server shutdown only after the backend context is exited await shutdown_task - sock.close() - if __name__ == "__main__": - # NOTE(simon): - # This section should be in sync with vllm/scripts.py for CLI entrypoints. 
- parser = FlexibleArgumentParser( - description="vLLM OpenAI-Compatible RESTful API server.") - parser = make_arg_parser(parser) + parser = FlexibleArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument("--ssl-ca-certs", + type=str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument("--log-level", type=str, default="debug") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument( "--load-in-low-bit", type=str, default="sym_int4", help="Low-bit quantization for IPEX-LLM models") args = parser.parse_args() - validate_parsed_serve_args(args) - uvloop.run(run_server(args)) \ No newline at end of file + asyncio.run(run_server(args)) diff --git a/python/llm/src/ipex_llm/vllm/cpu/entrypoints/openai/api_server.py b/python/llm/src/ipex_llm/vllm/cpu/entrypoints/openai/api_server.py index de6fa07fe54..c0ff6c397f0 100644 --- a/python/llm/src/ipex_llm/vllm/cpu/entrypoints/openai/api_server.py +++ b/python/llm/src/ipex_llm/vllm/cpu/entrypoints/openai/api_server.py @@ -1,5 +1,6 @@ import asyncio import atexit +import gc import importlib import inspect import multiprocessing @@ -7,16 +8,17 @@ import re import signal import socket +import sys import tempfile import uuid from argparse import Namespace from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Optional, Set, Tuple +from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, FastAPI, Request +from fastapi import APIRouter, FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -27,14 +29,14 @@ import vllm.envs as envs from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs -from ipex_llm.vllm.cpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine +from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore from vllm.engine.multiprocessing.client import MQLLMEngineClient from ipex_llm.vllm.cpu.engine import run_mp_engine from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.cli_args import (make_arg_parser, +from ipex_llm.vllm.cpu.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) # yapf conflicts with isort for this block # yapf: disable @@ -44,22 +46,31 @@ CompletionResponse, DetokenizeRequest, DetokenizeResponse, + EmbeddingChatRequest, + EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, LoadLoraAdapterRequest, + PoolingChatRequest, + PoolingCompletionRequest, PoolingRequest, PoolingResponse, + RerankRequest, RerankResponse, ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, 
UnloadLoraAdapterRequest) +from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling +from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) @@ -97,6 +108,11 @@ async def _force_log(): task.add_done_callback(_running_tasks.remove) else: task = None + + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. + gc.collect() + gc.freeze() try: yield finally: @@ -124,7 +140,7 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, - load_in_low_bit: str = "sym_int4", + load_in_low_bit: str = 'sym_int4', ) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: @@ -134,33 +150,22 @@ async def build_async_engine_client_from_engine_args( Returns the Client or None if the creation failed. """ - # Fall back - # TODO: fill out feature matrix. + # AsyncLLMEngine. if (MQLLMEngineClient.is_unsupported_config(engine_args) or envs.VLLM_USE_V1 or disable_frontend_multiprocessing): - engine_config = engine_args.create_engine_config( - UsageContext.OPENAI_API_SERVER) - uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), - "uses_ray", False) - - build_engine = partial(AsyncLLMEngine.from_engine_args, - engine_args=engine_args, - engine_config=engine_config, - load_in_low_bit=load_in_low_bit, - usage_context=UsageContext.OPENAI_API_SERVER) - if uses_ray: - # Must run in main thread with ray for its signal handlers to work - engine_client = build_engine() - else: - engine_client = await asyncio.get_running_loop().run_in_executor( - None, build_engine) - yield engine_client - if hasattr(engine_client, "shutdown"): - engine_client.shutdown() - return + engine_client: Optional[EngineClient] = None + try: + engine_client = AsyncLLMEngine.from_engine_args( + engine_args=engine_args, + load_in_low_bit=load_in_low_bit, + usage_context=UsageContext.OPENAI_API_SERVER) + yield engine_client + finally: + if engine_client and hasattr(engine_client, "shutdown"): + engine_client.shutdown() - # Otherwise, use the multiprocessing AsyncLLMEngine. + # MQLLMEngine. 
else: if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: # Make TemporaryDirectory for prometheus multiprocessing @@ -282,6 +287,10 @@ def base(request: Request) -> OpenAIServing: return tokenization(request) +def models(request: Request) -> OpenAIServingModels: + return request.app.state.openai_serving_models + + def chat(request: Request) -> Optional[OpenAIServingChat]: return request.app.state.openai_serving_chat @@ -302,6 +311,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]: return request.app.state.openai_serving_scores +def rerank(request: Request) -> Optional[JinaAIServingRerank]: + return request.app.state.jinaai_serving_reranking + + def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization @@ -317,6 +330,12 @@ async def health(raw_request: Request) -> Response: return Response(status_code=200) +@router.api_route("/ping", methods=["GET", "POST"]) +async def ping(raw_request: Request) -> Response: + """Ping check. Endpoint required for SageMaker""" + return await health(raw_request) + + @router.post("/tokenize") @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): @@ -349,10 +368,10 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): @router.get("/v1/models") async def show_available_models(raw_request: Request): - handler = base(raw_request) + handler = models(raw_request) - models = await handler.show_available_models() - return JSONResponse(content=models.model_dump()) + models_ = await handler.show_available_models() + return JSONResponse(content=models_.model_dump()) @router.get("/version") @@ -416,6 +435,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): "use the Pooling API (`/pooling`) instead.") res = await fallback_handler.create_pooling(request, raw_request) + + generator: Union[ErrorResponse, EmbeddingResponse] if isinstance(res, PoolingResponse): generator = EmbeddingResponse( id=res.id, @@ -490,6 +511,103 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) +@router.post("/rerank") +@with_cancellation +async def do_rerank(request: RerankRequest, raw_request: Request): + handler = rerank(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Rerank (Score) API") + generator = await handler.do_rerank(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, RerankResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/v1/rerank") +@with_cancellation +async def do_rerank_v1(request: RerankRequest, raw_request: Request): + logger.warning_once( + "To indicate that the rerank API is not part of the standard OpenAI" + " API, we have located it at `/rerank`. Please update your client" + "accordingly. 
(Note: Conforms to JinaAI rerank API)") + + return await do_rerank(request, raw_request) + + +@router.post("/v2/rerank") +@with_cancellation +async def do_rerank_v2(request: RerankRequest, raw_request: Request): + return await do_rerank(request, raw_request) + + +TASK_HANDLERS: Dict[str, Dict[str, tuple]] = { + "generate": { + "messages": (ChatCompletionRequest, create_chat_completion), + "default": (CompletionRequest, create_completion), + }, + "embed": { + "messages": (EmbeddingChatRequest, create_embedding), + "default": (EmbeddingCompletionRequest, create_embedding), + }, + "score": { + "default": (RerankRequest, do_rerank) + }, + "rerank": { + "default": (RerankRequest, do_rerank) + }, + "reward": { + "messages": (PoolingChatRequest, create_pooling), + "default": (PoolingCompletionRequest, create_pooling), + }, + "classify": { + "messages": (PoolingChatRequest, create_pooling), + "default": (PoolingCompletionRequest, create_pooling), + }, +} + +if envs.VLLM_SERVER_DEV_MODE: + + @router.post("/reset_prefix_cache") + async def reset_prefix_cache(raw_request: Request): + """ + Reset the prefix cache. Note that we currently do not check if the + prefix cache is successfully reset in the API server. + """ + logger.info("Resetting prefix cache...") + await engine_client(raw_request).reset_prefix_cache() + return Response(status_code=200) + + +@router.post("/invocations") +async def invocations(raw_request: Request): + """ + For SageMaker, routes requests to other handlers based on model `task`. + """ + body = await raw_request.json() + task = raw_request.app.state.task + + if task not in TASK_HANDLERS: + raise HTTPException( + status_code=400, + detail=f"Unsupported task: '{task}' for '/invocations'. " + f"Expected one of {set(TASK_HANDLERS.keys())}") + + handler_config = TASK_HANDLERS[task] + if "messages" in body: + request_model, handler = handler_config["messages"] + else: + request_model, handler = handler_config["default"] + + # this is required since we lose the FastAPI automatic casting + request = request_model.model_validate(body) + return await handler(request, raw_request) + + if envs.VLLM_TORCH_PROFILER_DIR: logger.warning( "Torch Profiler is enabled in the API server. 
This should ONLY be " @@ -518,26 +636,22 @@ async def stop_profile(raw_request: Request): @router.post("/v1/load_lora_adapter") async def load_lora_adapter(request: LoadLoraAdapterRequest, raw_request: Request): - for route in [chat, completion, embedding]: - handler = route(raw_request) - if handler is not None: - response = await handler.load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + handler = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @router.post("/v1/unload_lora_adapter") async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Request): - for route in [chat, completion, embedding]: - handler = route(raw_request) - if handler is not None: - response = await handler.unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + handler = models(raw_request) + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @@ -604,7 +718,7 @@ async def add_request_id(request: Request, call_next): module_path, object_name = middleware.rsplit(".", 1) imported = getattr(importlib.import_module(module_path), object_name) if inspect.isclass(imported): - app.add_middleware(imported) + app.add_middleware(imported) # type: ignore[arg-type] elif inspect.iscoroutinefunction(imported): app.middleware("http")(imported) else: @@ -614,7 +728,7 @@ async def add_request_id(request: Request, call_next): return app -def init_app_state( +async def init_app_state( engine_client: EngineClient, model_config: ModelConfig, state: State, @@ -641,34 +755,40 @@ def init_app_state( resolved_chat_template = load_chat_template(args.chat_template) logger.info("Using supplied chat template:\n%s", resolved_chat_template) + state.openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + ) + await state.openai_serving_models.init_static_loras() state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, - base_model_paths, + state.openai_serving_models, args.response_role, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, + enable_reasoning=args.enable_reasoning, + reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, ) if model_config.runner_type == "generate" else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, - base_model_paths, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, + state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) if model_config.runner_type == "generate" else None 
state.openai_serving_pooling = OpenAIServingPooling( engine_client, model_config, - base_model_paths, + state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -676,7 +796,7 @@ def init_app_state( state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, - base_model_paths, + state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -684,18 +804,24 @@ def init_app_state( state.openai_serving_scores = OpenAIServingScores( engine_client, model_config, - base_model_paths, + state.openai_serving_models, + request_logger=request_logger + ) if model_config.task == "score" else None + state.jinaai_serving_reranking = JinaAIServingRerank( + engine_client, + model_config, + state.openai_serving_models, request_logger=request_logger ) if model_config.task == "score" else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, - base_model_paths, - lora_modules=args.lora_modules, + state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) + state.task = model_config.task def create_server_socket(addr: Tuple[str, int]) -> socket.socket: @@ -717,11 +843,18 @@ async def run_server(args, **uvicorn_kwargs) -> None: if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: ToolParserManager.import_tool_parser(args.tool_parser_plugin) - valide_tool_parses = ToolParserManager.tool_parsers.keys() + valid_tool_parses = ToolParserManager.tool_parsers.keys() if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valide_tool_parses: + and args.tool_call_parser not in valid_tool_parses: raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " - f"(chose from {{ {','.join(valide_tool_parses)} }})") + f"(chose from {{ {','.join(valid_tool_parses)} }})") + + valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() + if args.enable_reasoning \ + and args.reasoning_parser not in valid_reasoning_parses: + raise KeyError( + f"invalid reasoning parser: {args.reasoning_parser} " + f"(chose from {{ {','.join(valid_reasoning_parses)} }})") # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. 
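As context for the pre-bind workaround noted just above (and the macOS `fd=` pass-through added in the next hunk), a standalone sketch of the pattern; the `prebind` helper and the commented-out uvicorn call are illustrative, not the vLLM API:

# Sketch: reserve the port before slow engine startup, then hand the
# already-bound socket to the HTTP server instead of letting it bind again.
import socket

def prebind(host: str, port: int) -> socket.socket:
    family = socket.AF_INET6 if ":" in host else socket.AF_INET
    sock = socket.socket(family, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind((host, port))
    return sock

sock = prebind("127.0.0.1", 0)  # port 0 picks a free port for this demo
print("port reserved:", sock.getsockname()[1])
# ... build the engine client here, then:
# uvicorn.run(app, fd=sock.fileno())  # uvicorn reuses the reserved socket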
@@ -743,7 +876,7 @@ def signal_handler(*_) -> None: app = build_app(args) model_config = await engine_client.get_model_config() - init_app_state(engine_client, model_config, app.state, args) + await init_app_state(engine_client, model_config, app.state, args) shutdown_task = await serve_http( app, @@ -755,6 +888,8 @@ def signal_handler(*_) -> None: ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, + # Workaround to work on macOS + fd=sock.fileno() if sys.platform.startswith("darwin") else None, **uvicorn_kwargs, ) @@ -770,11 +905,6 @@ def signal_handler(*_) -> None: parser = FlexibleArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) - parser.add_argument( - "--load-in-low-bit", - type=str, - default="sym_int4", - help="Low-bit quantization for IPEX-LLM models") args = parser.parse_args() validate_parsed_serve_args(args) diff --git a/python/llm/src/ipex_llm/vllm/cpu/entrypoints/openai/cli_args.py b/python/llm/src/ipex_llm/vllm/cpu/entrypoints/openai/cli_args.py index 92afa12d25c..bc06205d3ec 100644 --- a/python/llm/src/ipex_llm/vllm/cpu/entrypoints/openai/cli_args.py +++ b/python/llm/src/ipex_llm/vllm/cpu/entrypoints/openai/cli_args.py @@ -12,7 +12,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager +from vllm.entrypoints.openai.serving_models import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.utils import FlexibleArgumentParser @@ -79,29 +80,29 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--host", type=nullable_str, default=None, - help="host name") - parser.add_argument("--port", type=int, default=8000, help="port number") + help="Host name.") + parser.add_argument("--port", type=int, default=8000, help="Port number.") parser.add_argument( "--uvicorn-log-level", type=str, default="info", choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], - help="log level for uvicorn") + help="Log level for uvicorn.") parser.add_argument("--allow-credentials", action="store_true", - help="allow credentials") + help="Allow credentials.") parser.add_argument("--allowed-origins", type=json.loads, default=["*"], - help="allowed origins") + help="Allowed origins.") parser.add_argument("--allowed-methods", type=json.loads, default=["*"], - help="allowed methods") + help="Allowed methods.") parser.add_argument("--allowed-headers", type=json.loads, default=["*"], - help="allowed headers") + help="Allowed headers.") parser.add_argument("--api-key", type=nullable_str, default=None, @@ -115,10 +116,10 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action=LoRAParserAction, help="LoRA module configurations in either 'name=path' format" "or JSON format. 
" - "Example (old format): 'name=path' " + "Example (old format): ``'name=path'`` " "Example (new format): " - "'{\"name\": \"name\", \"local_path\": \"path\", " - "\"base_model_name\": \"id\"}'") + "``{\"name\": \"name\", \"path\": \"lora_path\", " + "\"base_model_name\": \"id\"}``") parser.add_argument( "--prompt-adapters", type=nullable_str, @@ -132,7 +133,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, help="The file path to the chat template, " "or the template in single-line form " - "for the specified model") + "for the specified model.") parser.add_argument( '--chat-template-content-format', type=str, @@ -141,38 +142,39 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='The format to render message content within a chat template.' '\n\n' '* "string" will render the content as a string. ' - 'Example: "Hello World"\n' + 'Example: ``"Hello World"``\n' '* "openai" will render the content as a list of dictionaries, ' 'similar to OpenAI schema. ' - 'Example: [{"type": "text", "text": "Hello world!"}]') + 'Example: ``[{"type": "text", "text": "Hello world!"}]``') parser.add_argument("--response-role", type=nullable_str, default="assistant", help="The role name to return if " - "`request.add_generation_prompt=true`.") + "``request.add_generation_prompt=true``.") parser.add_argument("--ssl-keyfile", type=nullable_str, default=None, - help="The file path to the SSL key file") + help="The file path to the SSL key file.") parser.add_argument("--ssl-certfile", type=nullable_str, default=None, - help="The file path to the SSL cert file") + help="The file path to the SSL cert file.") parser.add_argument("--ssl-ca-certs", type=nullable_str, default=None, - help="The CA certificates file") + help="The CA certificates file.") parser.add_argument( "--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE), - help="Whether client certificate is required (see stdlib ssl module's)" + help="Whether client certificate is required (see stdlib ssl module's)." ) parser.add_argument( "--root-path", type=nullable_str, default=None, - help="FastAPI root_path when app is behind a path based routing proxy") + help="FastAPI root_path when app is behind a path based routing proxy." + ) parser.add_argument( "--middleware", type=nullable_str, @@ -182,15 +184,15 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "We accept multiple --middleware arguments. " "The value should be an import path. " "If a function is provided, vLLM will add it to the server " - "using @app.middleware('http'). " + "using ``@app.middleware('http')``. " "If a class is provided, vLLM will add it to the server " - "using app.add_middleware(). ") + "using ``app.add_middleware()``. ") parser.add_argument( "--return-tokens-as-token-ids", action="store_true", - help="When --max-logprobs is specified, represents single tokens as " - "strings of the form 'token_id:{token_id}' so that tokens that " - "are not JSON-encodable can be identified.") + help="When ``--max-logprobs`` is specified, represents single tokens " + " as strings of the form 'token_id:{token_id}' so that tokens " + "that are not JSON-encodable can be identified.") parser.add_argument( "--disable-frontend-multiprocessing", action="store_true", @@ -205,8 +207,24 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--enable-auto-tool-choice", action="store_true", default=False, - help="Enable auto tool choice for supported models. 
Use --tool-call-parser" - " to specify which parser to use") + help="Enable auto tool choice for supported models. Use " + "``--tool-call-parser`` to specify which parser to use.") + parser.add_argument( + "--enable-reasoning", + action="store_true", + default=False, + help="Whether to enable reasoning_content for the model. " + "If enabled, the model will be able to generate reasoning content.") + + valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys() + parser.add_argument( + "--reasoning-parser", + type=str, + metavar="{" + ",".join(valid_reasoning_parsers) + "}", + default=None, + help="Select the reasoning parser depending on the model that you're using." + " This is used to parse the reasoning content into OpenAI API " + "format. Required for ``--enable-reasoning``.") valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( @@ -217,7 +235,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, help="Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " - "format. Required for --enable-auto-tool-choice.") + "format. Required for ``--enable-auto-tool-choice``.") parser.add_argument( "--tool-parser-plugin", @@ -225,7 +243,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default="", help="Special the tool parser plugin write to parse the model-generated tool" " into OpenAI API format, the name register in this plugin can be used " - "in --tool-call-parser.") + "in ``--tool-call-parser``.") parser = AsyncEngineArgs.add_cli_args(parser) @@ -240,7 +258,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--disable-fastapi-docs", action='store_true', default=False, - help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint" + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint." ) parser.add_argument( "--enable-prompt-tokens-details", @@ -270,6 +288,18 @@ def validate_parsed_serve_args(args: argparse.Namespace): raise TypeError("Error: --enable-auto-tool-choice requires " # noqa "--tool-call-parser") + # Enable reasoning needs a reasoning parser to be valid + if args.enable_reasoning and not args.reasoning_parser: + raise TypeError("Error: --enable-reasoning requires " # noqa + "--reasoning-parser") + + # Ref https://api-docs.deepseek.com/guides/reasoning_model + # tool call and reasoning cannot be enabled at the same time. 
+ if args.enable_auto_tool_choice and args.enable_reasoning: + raise TypeError( # noqa + "Error: --enable-auto-tool-choice and " + "--enable-reasoning cannot be enabled at the same time") + def create_parser_for_docs() -> FlexibleArgumentParser: parser_for_docs = FlexibleArgumentParser( diff --git a/python/llm/src/ipex_llm/vllm/cpu/ipex_llm_v1_wrapper.py b/python/llm/src/ipex_llm/vllm/cpu/ipex_llm_v1_wrapper.py deleted file mode 100644 index 2dc81bb6bed..00000000000 --- a/python/llm/src/ipex_llm/vllm/cpu/ipex_llm_v1_wrapper.py +++ /dev/null @@ -1,23 +0,0 @@ -from vllm.logger import init_logger -from vllm.v1.executor.ray_utils import RayWorkerWrapper - - -logger = init_logger(__name__) - - -class IPEXLLMV1Wrapper(RayWorkerWrapper): - def __init__(self, load_in_low_bit="sym_int4", *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert - _ipex_llm_convert(load_in_low_bit=load_in_low_bit) - self.compiled_dag_cuda_device_set = False - - -def get_ipex_llm_v1_wrapper(load_in_low_bit): - # The reason why we not using functools.partial is that - # ray seems not work well with it. - class WrapperWithLoadBit(IPEXLLMV1Wrapper): - def __init__(self, *args, **kwargs) -> None: - super().__init__(load_in_low_bit=load_in_low_bit, *args, **kwargs) - - return WrapperWithLoadBit diff --git a/python/llm/src/ipex_llm/vllm/cpu/model_convert.py b/python/llm/src/ipex_llm/vllm/cpu/model_convert.py index 803328f3f83..ca05b0dd3a1 100644 --- a/python/llm/src/ipex_llm/vllm/cpu/model_convert.py +++ b/python/llm/src/ipex_llm/vllm/cpu/model_convert.py @@ -48,7 +48,7 @@ def _sample_get_logits( logits = lm_head(hidden_states) if embedding_bias is not None: logits += embedding_bias - if self.use_gather: + if not self.use_all_gather: logits = tensor_model_parallel_gather(logits) else: logits = tensor_model_parallel_all_gather(logits) @@ -65,12 +65,9 @@ def _model_sample_convert(): def _ipex_llm_convert(load_in_low_bit): from vllm.worker.cpu_model_runner import CPUModelRunner from ipex_llm.vllm.cpu.ipex_llm_wrapper import get_ipex_llm_wrapper - from ipex_llm.vllm.cpu.ipex_llm_v1_wrapper import get_ipex_llm_v1_wrapper - import vllm.executor.ray_utils as ray_utils_v0 - import vllm.v1.executor.ray_utils as ray_utils_v1 + import vllm.executor.ray_utils as ray_utils setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit)) - setattr(ray_utils_v0, "RayWorkerWrapper", get_ipex_llm_wrapper(load_in_low_bit)) - setattr(ray_utils_v1, "RayWorkerWrapper", get_ipex_llm_v1_wrapper(load_in_low_bit)) + setattr(ray_utils, "RayWorkerWrapper", get_ipex_llm_wrapper(load_in_low_bit)) def get_load_function(low_bit):
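The trimmed `_ipex_llm_convert` above keeps the same setattr-based patching, now applied only to the v0 `ray_utils`. A standalone sketch of that pattern with stub classes (placeholder names, not the real IPEX-LLM or vLLM symbols):

# Sketch: swap in a low-bit-aware load_model at import time so existing
# call sites pick up the replacement without changing their code.
class CPUModelRunnerStub:
    def load_model(self):
        return "load full-precision model"

def get_load_function_stub(low_bit: str):
    def load_model(self):
        # stand-in for loading weights and converting them to the low-bit format
        return f"load model converted to {low_bit}"
    return load_model

setattr(CPUModelRunnerStub, "load_model", get_load_function_stub("sym_int4"))
print(CPUModelRunnerStub().load_model())  # -> load model converted to sym_int4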