diff --git a/examples/python_rs/llm/vllm/README.md b/examples/python_rs/llm/vllm/README.md
new file mode 100644
index 00000000..09347444
--- /dev/null
+++ b/examples/python_rs/llm/vllm/README.md
@@ -0,0 +1,141 @@
+
+
+# vLLM Integration with Triton Distributed
+
+This example demonstrates how to use Triton Distributed to serve large language models with the vLLM engine, enabling efficient model serving with both monolithic and disaggregated deployment options.
+
+## Prerequisites
+
+1. Follow the setup instructions in the Python bindings [README](/runtime/rust/python-wheel/README.md) to prepare your environment.
+
+2. Install vLLM:
+   ```bash
+   uv pip install vllm==0.7.2
+   ```
+
+3. Start required services (etcd and NATS):
+
+   Option A: Using [Docker Compose](/runtime/rust/docker-compose.yml) (Recommended)
+   ```bash
+   docker-compose up -d
+   ```
+
+   Option B: Manual Setup
+
+   - [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) server with [Jetstream](https://docs.nats.io/nats-concepts/jetstream)
+     - example: `nats-server -js --trace`
+   - [etcd](https://etcd.io) server
+     - follow the instructions in [etcd installation](https://etcd.io/docs/v3.5/install/) to start an `etcd` server locally
+
+## Deployment Options
+
+### 1. Monolithic Deployment
+
+Run the server and client components in separate terminal sessions:
+
+**Terminal 1 - Server:**
+```bash
+python3 -m monolith.worker \
+    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
+    --max-model-len 100 \
+    --enforce-eager
+```
+
+**Terminal 2 - Client:**
+```bash
+python3 -m common.client \
+    --prompt "what is the capital of france?" \
+    --max-tokens 10 \
+    --temperature 0.5
+```
+
+The output should look similar to:
+```
+Annotated(data=' Well', event=None, comment=[], id=None)
+Annotated(data=' Well,', event=None, comment=[], id=None)
+Annotated(data=' Well, France', event=None, comment=[], id=None)
+Annotated(data=' Well, France is', event=None, comment=[], id=None)
+Annotated(data=' Well, France is a', event=None, comment=[], id=None)
+Annotated(data=' Well, France is a country', event=None, comment=[], id=None)
+Annotated(data=' Well, France is a country located', event=None, comment=[], id=None)
+Annotated(data=' Well, France is a country located in', event=None, comment=[], id=None)
+Annotated(data=' Well, France is a country located in Western', event=None, comment=[], id=None)
+Annotated(data=' Well, France is a country located in Western Europe', event=None, comment=[], id=None)
+```
+
+### 2. Disaggregated Deployment
+
+This deployment option splits model serving across separate prefill and decode workers, enabling more efficient resource utilization.
+
+**Terminal 1 - Prefill Worker:**
+```bash
+CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
+    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
+    --max-model-len 100 \
+    --gpu-memory-utilization 0.8 \
+    --enforce-eager \
+    --kv-transfer-config \
+    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
+```
+
+**Terminal 2 - Decode Worker:**
+```bash
+CUDA_VISIBLE_DEVICES=1 python3 -m disaggregated.decode_worker \
+    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
+    --max-model-len 100 \
+    --gpu-memory-utilization 0.8 \
+    --enforce-eager \
+    --kv-transfer-config \
+    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
+```
+
+**Terminal 3 - Client:**
+```bash
+python3 -m common.client \
+    --prompt "what is the capital of france?" \
+    --max-tokens 10 \
+    --temperature 0.5
+```
+
+The disaggregated deployment runs prefill and decode on separate GPUs, allowing resources to be allocated and tuned for each phase independently. For more details, refer to the [vLLM documentation](https://docs.vllm.ai/en/latest/features/disagg_prefill.html).
+
+### 3. Multi-Node Deployment
+
+The vLLM workers can be deployed across multiple nodes by configuring the NATS and etcd connection endpoints through environment variables. This enables distributed inference across a cluster.
+
+Set the following environment variables on each node before running the workers (the angle-bracketed values are placeholders for your cluster's addresses):
+
+```bash
+export NATS_SERVER="nats://<nats-server-host>:<nats-server-port>"
+export ETCD_ENDPOINTS="http://<etcd-host-1>:<etcd-port>,http://<etcd-host-2>:<etcd-port>,..."
+```
+
+For disaggregated deployment, you will also need to pass the `kv_ip` and `kv_port` to the workers in the `kv_transfer_config` argument:
+
+```bash
+...
+    --kv-transfer-config \
+    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":<kv_rank>,"kv_parallel_size":2,"kv_ip":<kv_ip>,"kv_port":<kv_port>}'
+```
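For concreteness, the disaggregated commands above can be combined with these multi-node settings roughly as follows. This is an illustrative sketch, not part of the example's files: the host addresses and ports (`10.0.0.1`, `4222`, `2379`, `14579`) are hypothetical placeholders for your own cluster, and only the flags already shown in this README are assumed.

```bash
# Hypothetical two-node layout: NATS and etcd run on node A (10.0.0.1),
# the prefill worker runs on node A, and the decode worker on node B.
# Every address and port below is a placeholder.

# On both nodes: point the workers at the shared NATS and etcd services.
export NATS_SERVER="nats://10.0.0.1:4222"
export ETCD_ENDPOINTS="http://10.0.0.1:2379"

# Node A - prefill worker (KV producer, rank 0).
CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --enforce-eager \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_ip":"10.0.0.1","kv_port":14579}'

# Node B - decode worker (KV consumer, rank 1); kv_ip and kv_port must
# match the values given to the prefill worker.
CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.decode_worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --enforce-eager \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_ip":"10.0.0.1","kv_port":14579}'
```

The rank and `kv_parallel_size` values mirror the two-GPU example above; see the vLLM disaggregated prefill documentation linked earlier for the full set of options.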
diff --git a/examples/python_rs/llm/vllm/common/__init__.py b/examples/python_rs/llm/vllm/common/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/python_rs/llm/vllm/common/client.py b/examples/python_rs/llm/vllm/common/client.py
new file mode 100644
index 00000000..b33c1c38
--- /dev/null
+++ b/examples/python_rs/llm/vllm/common/client.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+
+import uvloop
+from triton_distributed_rs import DistributedRuntime, triton_worker
+from vllm.utils import FlexibleArgumentParser
+
+from .protocol import Request
+
+
+@triton_worker()
+async def worker(
+    runtime: DistributedRuntime, prompt: str, max_tokens: int, temperature: float
+):
+    """
+    Instantiate a client for the `vllm` component and call its `generate` endpoint
+    """
+    # get endpoint
+    endpoint = runtime.namespace("triton-init").component("vllm").endpoint("generate")
+
+    # create client
+    client = await endpoint.client()
+
+    # list the endpoint ids
+    print(client.endpoint_ids())
+
+    # issue request
+    stream = await client.generate(
+        Request(
+            prompt=prompt,
+            sampling_params={"temperature": temperature, "max_tokens": max_tokens},
+        ).model_dump_json()
+    )
+
+    # process response
+    async for resp in stream:
+        print(resp)
+
+
+if __name__ == "__main__":
+    uvloop.install()
+
+    parser = FlexibleArgumentParser()
+    parser.add_argument("--prompt", type=str, default="what is the capital of france?")
+    parser.add_argument("--max-tokens", type=int, default=10)
+    parser.add_argument("--temperature", type=float, default=0.5)
+
+    args = parser.parse_args()
+
+    asyncio.run(worker(args.prompt, args.max_tokens, args.temperature))
diff --git a/examples/python_rs/llm/vllm/common/parser.py b/examples/python_rs/llm/vllm/common/parser.py
new file mode 100644
index 00000000..8d5946ae
--- /dev/null
+++ b/examples/python_rs/llm/vllm/common/parser.py
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+
+def parse_vllm_args() -> AsyncEngineArgs:
+    parser = FlexibleArgumentParser()
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    return AsyncEngineArgs.from_cli_args(args)
diff --git a/examples/python_rs/llm/vllm/common/protocol.py b/examples/python_rs/llm/vllm/common/protocol.py
new file mode 100644
index 00000000..1e5065ff
--- /dev/null
+++ b/examples/python_rs/llm/vllm/common/protocol.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pydantic import BaseModel
+
+
+class Request(BaseModel):
+    prompt: str
+    sampling_params: dict
+
+
+class PrefillRequest(Request):
+    request_id: str
+
+
+class Response(BaseModel):
+    text: str
+
+
+class PrefillResponse(BaseModel):
+    prefilled: bool
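As a quick illustration of the wire format these models define (this snippet is not part of the diff): the client serializes a `Request` with `model_dump_json()` before sending it, and the decode worker shown below wraps the same fields in a `PrefillRequest` that adds a `request_id`. Assuming the package layout above is importable, the payloads look roughly like this:

```python
# Illustrative only: build the payloads the same way client.py and
# decode_worker.py do, and print the JSON that travels over the runtime.
import uuid

from common.protocol import PrefillRequest, Request

request = Request(
    prompt="what is the capital of france?",
    sampling_params={"temperature": 0.5, "max_tokens": 10},
)
print(request.model_dump_json())
# -> {"prompt":"what is the capital of france?","sampling_params":{"temperature":0.5,"max_tokens":10}}

# The decode worker caps the prefill pass at one token and reuses its own
# request_id for the later decode call.
prefill_request = PrefillRequest(
    prompt=request.prompt,
    sampling_params={**request.sampling_params, "max_tokens": 1},
    request_id=str(uuid.uuid4()),
)
print(prefill_request.model_dump_json())
```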
diff --git a/examples/python_rs/llm/vllm/disaggregated/__init__.py b/examples/python_rs/llm/vllm/disaggregated/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/python_rs/llm/vllm/disaggregated/decode_worker.py b/examples/python_rs/llm/vllm/disaggregated/decode_worker.py
new file mode 100644
index 00000000..2fc9fa17
--- /dev/null
+++ b/examples/python_rs/llm/vllm/disaggregated/decode_worker.py
@@ -0,0 +1,92 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import uuid
+
+import uvloop
+import vllm
+from common.parser import parse_vllm_args
+from common.protocol import PrefillRequest, Request, Response
+from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.logger import logger as vllm_logger
+
+
+class VllmDecodeEngine:
+    """
+    Request handler for the generate endpoint
+    """
+
+    def __init__(self, engine_args: AsyncEngineArgs, prefill):
+        assert (
+            engine_args.kv_transfer_config.is_kv_consumer
+        ), "Decode worker must be a KV consumer"
+        self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
+        self.prefill = prefill
+
+    @triton_endpoint(Request, Response)
+    async def generate(self, request):
+        vllm_logger.info(f"Received request: {request}")
+        sampling_params = vllm.SamplingParams(**request.sampling_params)
+        request_id = str(uuid.uuid4())
+
+        # Send the request to the prefill worker first, capped at a single
+        # token, so it produces the KV cache before local decoding starts.
+        prefill_sampling_params = {**request.sampling_params}
+        prefill_sampling_params["max_tokens"] = 1
+        prefill_request = PrefillRequest(
+            prompt=request.prompt,
+            sampling_params=prefill_sampling_params,
+            request_id=request_id,
+        )
+        prefill_generator = await self.prefill.generate(
+            prefill_request.model_dump_json()
+        )
+        prefill_response = [resp async for resp in prefill_generator]
+        assert len(prefill_response) == 1, "Prefill response should be a single boolean"
+        prefill_response = prefill_response[0]
+        vllm_logger.debug(f"Prefill response: {prefill_response}")
+
+        # Decode locally, reusing the request_id that was sent to the prefill worker.
+        async for response in self.engine.generate(
+            request.prompt, sampling_params, request_id
+        ):
+            vllm_logger.debug(f"Generated response: {response}")
+            yield response.outputs[0].text
+
+
+@triton_worker()
+async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
+    """
+    Instantiate the `vllm` component and serve the `generate` endpoint
+    A `Component` can serve multiple endpoints
+    """
+    component = runtime.namespace("triton-init").component("vllm")
+    await component.create_service()
+
+    # client for the prefill component's `generate` endpoint
+    prefill = (
+        await runtime.namespace("triton-init")
+        .component("prefill")
+        .endpoint("generate")
+        .client()
+    )
+
+    endpoint = component.endpoint("generate")
+    await endpoint.serve_endpoint(VllmDecodeEngine(engine_args, prefill).generate)
+
+
+if __name__ == "__main__":
+    uvloop.install()
+    engine_args = parse_vllm_args()
+    asyncio.run(worker(engine_args))
diff --git a/examples/python_rs/llm/vllm/disaggregated/prefill_worker.py b/examples/python_rs/llm/vllm/disaggregated/prefill_worker.py
new file mode 100644
index 00000000..2f0a9481
--- /dev/null
+++ b/examples/python_rs/llm/vllm/disaggregated/prefill_worker.py
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+
+import uvloop
+import vllm
+from common.parser import parse_vllm_args
+from common.protocol import PrefillRequest, PrefillResponse
+from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.logger import logger as vllm_logger
+
+
+class VllmPrefillEngine:
+    """
+    Request handler for the generate endpoint
+    """
+
+    def __init__(self, engine_args: AsyncEngineArgs):
+        assert (
+            engine_args.kv_transfer_config.is_kv_producer
+        ), "Prefill worker must be a KV producer"
+        self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
+
+    @triton_endpoint(PrefillRequest, PrefillResponse)
+    async def generate(self, request):
+        vllm_logger.info(f"Received prefill request: {request}")
+        sampling_params = vllm.SamplingParams(**request.sampling_params)
+        # The decode worker caps prefill requests at a single token, so this
+        # loop yields one acknowledgement once the prefill pass has run.
+        async for response in self.engine.generate(
+            request.prompt, sampling_params, request.request_id
+        ):
+            vllm_logger.debug(f"Generated response: {response}")
+            yield True
+
+
+@triton_worker()
+async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
+    """
+    Instantiate the `prefill` component and serve the `generate` endpoint
+    A `Component` can serve multiple endpoints
+    """
+    component = runtime.namespace("triton-init").component("prefill")
+    await component.create_service()
+
+    endpoint = component.endpoint("generate")
+    await endpoint.serve_endpoint(VllmPrefillEngine(engine_args).generate)
+
+
+if __name__ == "__main__":
+    uvloop.install()
+    engine_args = parse_vllm_args()
+    asyncio.run(worker(engine_args))
diff --git a/examples/python_rs/llm/vllm/monolith/__init__.py b/examples/python_rs/llm/vllm/monolith/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/python_rs/llm/vllm/monolith/worker.py b/examples/python_rs/llm/vllm/monolith/worker.py
new file mode 100644
index 00000000..6951b828
--- /dev/null
+++ b/examples/python_rs/llm/vllm/monolith/worker.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import uuid
+
+import uvloop
+import vllm
+from common.parser import parse_vllm_args
+from common.protocol import Request, Response
+from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.logger import logger as vllm_logger
+
+
+class VllmEngine:
+    """
+    Request handler for the generate endpoint
+    """
+
+    def __init__(self, engine_args: AsyncEngineArgs):
+        self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
+
+    @triton_endpoint(Request, Response)
+    async def generate(self, request):
+        vllm_logger.debug(f"Received request: {request}")
+        sampling_params = vllm.SamplingParams(**request.sampling_params)
+        request_id = str(uuid.uuid4())
+        async for response in self.engine.generate(
+            request.prompt, sampling_params, request_id
+        ):
+            vllm_logger.debug(f"Generated response: {response}")
+            yield response.outputs[0].text
+
+
+@triton_worker()
+async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
+    """
+    Instantiate the `vllm` component and serve the `generate` endpoint
+    A `Component` can serve multiple endpoints
+    """
+    component = runtime.namespace("triton-init").component("vllm")
+    await component.create_service()
+
+    endpoint = component.endpoint("generate")
+    await endpoint.serve_endpoint(VllmEngine(engine_args).generate)
+
+
+if __name__ == "__main__":
+    uvloop.install()
+    engine_args = parse_vllm_args()
+    asyncio.run(worker(engine_args))