feat: vLLM DistributedRuntime Monolith and Disagg Workers Example #113

Merged · 21 commits · Feb 7, 2025
141 changes: 141 additions & 0 deletions examples/python_rs/llm/vllm/README.md
@@ -0,0 +1,141 @@
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# vLLM Integration with Triton Distributed

This example demonstrates how to use Triton Distributed to serve large language models with the vLLM engine, enabling efficient model serving with both monolithic and disaggregated deployment options.

## Prerequisites
**@rmccorm4 (Contributor) commented on Feb 7, 2025:**
@ptarasiewiczNV Is there an assumed container as the base environment for these steps too? Such as the:

```bash
./container/build.sh --framework vllm
./container/run.sh --framework vllm -it
```

flow beforehand?

All the rust/runtime examples that don't use vllm or GPUs work fine on my host, so I tried that here too since the README doesn't mention any containers:

```bash
cd triton_distributed/
git checkout ptarasiewicz/vllm-example-rust-runtime
cd runtime/rust/python-wheel/
uv venv
source .venv/bin/activate
uv pip install maturin
maturin develop --uv
uv pip install vllm==0.7.2
```

But when using vllm on the host (no container), I got this error when following the steps to launch the monolith worker (note: I don't have the CUDA toolkit/runtime installed globally on my host, only the driver -- but the venv installs the CUDA toolkit during the vllm/pytorch install):

```
$ python3 -m monolith.worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --enforce-eager
...
ImportError: /home/rmccormick/triton/distributed/v0.2.0/triton_distributed/runtime/rust/python-wheel/.venv/lib/python3.10/site-packages/torch/lib/../../nvidia/cusparse/lib/libcusparse.so.12: undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12
```

This seems to be work-aroundable by adding libnvjitlink to the LD_LIBRARY_PATH, e.g.:

```bash
rmccormick@ced35d0-lcedt:~/triton/distributed/v0.2.0/triton_distributed/examples/python_rs/llm/vllm$ export LD_LIBRARY_PATH=$PWD/../../../../runtime/rust/python-wheel/.venv/lib64/python3.10/site-packages/nvidia/nvjitlink/lib:${LD_LIBRARY_PATH}
```

Similar issue: pytorch/pytorch#111469
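
For reference, a path-agnostic sketch of the same workaround, assuming the venv layout shown in the error above and that `VIRTUAL_ENV` was set by `source .venv/bin/activate` (the Python version segment may differ on other systems):

```bash
# Sketch: prepend the venv's bundled nvjitlink directory to the loader search path.
export LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python3.10/site-packages/nvidia/nvjitlink/lib:${LD_LIBRARY_PATH}"
```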

**Contributor commented:**
After the above workaround (WAR) to put libnvjitlink on the LD_LIBRARY_PATH, I can run the monolith worker and client:

```
rmccormick@ced35d0-lcedt:~/triton/distributed/v0.2.0/triton_distributed/runtime/rust/python-wheel$ source .venv/bin/activate

(python-wheel) rmccormick@ced35d0-lcedt:~/triton/distributed/v0.2.0/triton_distributed/runtime/rust/python-wheel$ cd ../../../examples/python_rs/llm/vllm

(python-wheel) rmccormick@ced35d0-lcedt:~/triton/distributed/v0.2.0/triton_distributed/examples/python_rs/llm/vllm$ python3 -m common.client \
    --prompt "what is the capital of france?" \
    --max-tokens 10 \
    --temperature 0.5
INFO 02-07 11:18:48 __init__.py:190] Automatically detected platform cuda.
...
[7587884607120538396]
Annotated(data=' Well', event=None, comment=[], id=None)
Annotated(data=' Well,', event=None, comment=[], id=None)
Annotated(data=' Well, France', event=None, comment=[], id=None)
Annotated(data=' Well, France is', event=None, comment=[], id=None)
Annotated(data=' Well, France is a', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western Europe', event=None, comment=[], id=None)
```

**Contributor (PR author) commented:**

I did not assume container usage. I was running this both through the NGC PyTorch container and on bare metal on my desktop. I like the idea of not having to run it through the container, as the user should be able to just install the wheels.

We can still refer to our container for reproduction, but it would have to be

```bash
./container/build.sh
./container/run.sh -it
```

as the vLLM container has some settings required for the old example. Or do we just remove the old example right now, @nnshah1 ?

**Contributor commented:**
Assuming the same fixes mentioned above, I'm able to run the disaggregated example on a single node with 2 GPUs (heterogeneous GPUs, too!)

Prefill Worker

```
# Make sure venv is activated, and LD_LIBRARY_PATH has necessary libs from venv
$ CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --enforce-eager \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
...
INFO 02-07 11:27:46 model_runner.py:1115] Loading model weights took 14.9888 GB
INFO 02-07 11:27:47 worker.py:267] Memory profiling takes 0.98 seconds
INFO 02-07 11:27:47 worker.py:267] the current vLLM instance can use total_gpu_memory (47.50GiB) x gpu_memory_utilization (0.80) = 38.00GiB
INFO 02-07 11:27:47 worker.py:267] model weights take 14.99GiB; non_torch_memory takes 0.17GiB; PyTorch activation peak memory takes 1.19GiB; the rest of the memory reserved for KV Cache is 21.66GiB.
INFO 02-07 11:27:47 executor_base.py:110] # CUDA blocks: 11089, # CPU blocks: 2048
INFO 02-07 11:27:47 executor_base.py:115] Maximum concurrency for 100 tokens per request: 1774.24x
INFO 02-07 11:27:50 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 4.11 seconds
INFO 02-07 11:28:09 prefill_worker.py:41] Received prefill request: prompt='what is the capital of france?' sampling_params={'temperature': 0.5, 'max_tokens': 1} request_id='84bf7368-f5b1-43e2-89da-e46867af5ac8'
INFO 02-07 11:28:09 async_llm_engine.py:211] Added request 84bf7368-f5b1-43e2-89da-e46867af5ac8.
INFO 02-07 11:28:09 metrics.py:455] Avg prompt throughput: 0.4 tokens/s, Avg generation throughput: 0.1 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 02-07 11:28:09 async_llm_engine.py:179] Finished request 84bf7368-f5b1-43e2-89da-e46867af5ac8.
```

Decode Worker

```
# Make sure venv is activated, and LD_LIBRARY_PATH has necessary libs from venv
$ CUDA_VISIBLE_DEVICES=1 python3 -m disaggregated.decode_worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --enforce-eager \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
...
INFO 02-07 11:27:46 model_runner.py:1115] Loading model weights took 14.9888 GB
INFO 02-07 11:27:47 worker.py:267] Memory profiling takes 1.16 seconds
INFO 02-07 11:27:47 worker.py:267] the current vLLM instance can use total_gpu_memory (23.69GiB) x gpu_memory_utilization (0.80) = 18.95GiB
INFO 02-07 11:27:47 worker.py:267] model weights take 14.99GiB; non_torch_memory takes 0.14GiB; PyTorch activation peak memory takes 1.19GiB; the rest of the memory reserved for KV Cache is 2.63GiB.
INFO 02-07 11:27:47 executor_base.py:110] # CUDA blocks: 1347, # CPU blocks: 2048
INFO 02-07 11:27:47 executor_base.py:115] Maximum concurrency for 100 tokens per request: 215.52x
INFO 02-07 11:27:50 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 4.31 seconds
INFO 02-07 11:28:09 decode_worker.py:43] Received request: prompt='what is the capital of france?' sampling_params={'temperature': 0.5, 'max_tokens': 10}
INFO 02-07 11:28:09 async_llm_engine.py:211] Added request 84bf7368-f5b1-43e2-89da-e46867af5ac8.
INFO 02-07 11:28:09 metrics.py:455] Avg prompt throughput: 0.4 tokens/s, Avg generation throughput: 0.1 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.1%, CPU KV cache usage: 0.0%.
INFO 02-07 11:28:10 async_llm_engine.py:179] Finished request 84bf7368-f5b1-43e2-89da-e46867af5ac8.
```

Client

```
# Make sure venv is activated, and LD_LIBRARY_PATH has necessary libs from venv
(python-wheel) rmccormick@ced35d0-lcedt:~/triton/distributed/v0.2.0/triton_distributed/examples/python_rs/llm/vllm$ python3 -m common.client \
    --prompt "what is the capital of france?" \
    --max-tokens 10 \
    --temperature 0.5

INFO 02-07 11:28:09 __init__.py:190] Automatically detected platform cuda.
WARNING 02-07 11:28:09 cuda.py:336] Detected different devices in the system:
WARNING 02-07 11:28:09 cuda.py:336] NVIDIA GeForce RTX 3090
WARNING 02-07 11:28:09 cuda.py:336] NVIDIA RTX 5880 Ada Generation
WARNING 02-07 11:28:09 cuda.py:336] Please make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to avoid unexpected behavior.
[7587884607120538406]
Annotated(data=' Well', event=None, comment=[], id=None)
Annotated(data=' Well,', event=None, comment=[], id=None)
Annotated(data=' Well, France', event=None, comment=[], id=None)
Annotated(data=' Well, France is', event=None, comment=[], id=None)
Annotated(data=' Well, France is a', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western Europe', event=None, comment=[], id=None)
```

**Collaborator commented:** let's not remove the old example yet - let's update the instructions now


1. Follow the setup instructions in the Python bindings [README](/runtime/rust/python-wheel/README.md) to prepare your environment

2. Install vLLM:
```bash
uv pip install vllm==0.7.2
```

3. Start required services (etcd and NATS):

Option A: Using [Docker Compose](/runtime/rust/docker-compose.yml) (Recommended)
```bash
docker-compose up -d
```

Option B: Manual Setup (a minimal startup sketch follows this list)

- [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) server with [Jetstream](https://docs.nats.io/nats-concepts/jetstream)
  - example: `nats-server -js --trace`
- [etcd](https://etcd.io) server
  - follow the instructions in [etcd installation](https://etcd.io/docs/v3.5/install/) to start an `etcd` server locally
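
For a quick local test, both services can also be started directly from their binaries. A minimal sketch, assuming `nats-server` and `etcd` are installed and on `PATH` (NATS uses its default client port 4222, etcd its default client port 2379):

```bash
# Start NATS with JetStream enabled (default client port 4222)
nats-server -js --trace &

# Start a single-node etcd instance (default client port 2379)
etcd &
```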

## Deployment Options

### 1. Monolithic Deployment

Run the server and client components in separate terminal sessions:

**Terminal 1 - Server:**
```bash
python3 -m monolith.worker \
**Collaborator commented:** is this in the container or outside the container - or not related to container?

**Contributor (PR author) replied:** I assumed no container, meaning just running in the user's environment, as we now have wheels.

    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --enforce-eager
```

**Terminal 2 - Client:**
```bash
python3 -m common.client \
**Collaborator commented:** same here - is this from the same container / separate?

    --prompt "what is the capital of france?" \
    --max-tokens 10 \
    --temperature 0.5
```

The output should look similar to:
```
Annotated(data=' Well', event=None, comment=[], id=None)
Annotated(data=' Well,', event=None, comment=[], id=None)
Annotated(data=' Well, France', event=None, comment=[], id=None)
Annotated(data=' Well, France is', event=None, comment=[], id=None)
Annotated(data=' Well, France is a', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western Europe', event=None, comment=[], id=None)
```


### 2. Disaggregated Deployment

This deployment option splits the model serving across prefill and decode workers, enabling more efficient resource utilization.

**Terminal 1 - Prefill Worker:**
```bash
CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --enforce-eager \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
```

**Terminal 2 - Decode Worker:**
```bash
CUDA_VISIBLE_DEVICES=1 python3 -m disaggregated.decode_worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --enforce-eager \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
```

**Terminal 3 - Client:**
```bash
python3 -m common.client \
    --prompt "what is the capital of france?" \
    --max-tokens 10 \
    --temperature 0.5
```

The disaggregated deployment utilizes separate GPUs for prefill and decode operations, allowing for optimized resource allocation and improved performance. For more details on the disaggregated deployment, please refer to the [vLLM documentation](https://docs.vllm.ai/en/latest/features/disagg_prefill.html).



### 3. Multi-Node Deployment

The vLLM workers can be deployed across multiple nodes by configuring the NATS and etcd connection endpoints through environment variables. This enables distributed inference across a cluster.

Set the following environment variables on each node before running the workers:

```bash
export NATS_SERVER="nats://<nats-server-host>:<nats-server-port>"
export ETCD_ENDPOINTS="http://<etcd-server-host1>:<etcd-server-port>,http://<etcd-server-host2>:<etcd-server-port>,..."
```

For a disaggregated deployment, you will also need to pass `kv_ip` and `kv_port` to the workers in the `--kv-transfer-config` argument:

```bash
...
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":<rank>,"kv_parallel_size":2,"kv_ip":"<master_node_ip>","kv_port":<kv_port>}'
```
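
Putting these pieces together, a sketch of launching the prefill worker on the head node of a two-node disaggregated deployment might look like this (the host names, IP address, and port below are placeholders, not defaults to rely on; the decode worker on the second node would use the same `kv_ip`/`kv_port` with `"kv_role":"kv_consumer"` and `"kv_rank":1`):

```bash
# Sketch only: replace the placeholder endpoints and addresses with your own.
export NATS_SERVER="nats://nats-host.example.com:4222"
export ETCD_ENDPOINTS="http://etcd-host.example.com:2379"

CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --enforce-eager \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_ip":"10.0.0.1","kv_port":14579}'
```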




65 changes: 65 additions & 0 deletions examples/python_rs/llm/vllm/common/client.py
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio

import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker
from vllm.utils import FlexibleArgumentParser

from .protocol import Request


@triton_worker()
**Collaborator commented:** to do - does seem this would be better as `@triton_distributed_component`

async def worker(
    runtime: DistributedRuntime, prompt: str, max_tokens: int, temperature: float
):
    """
    Instantiate a `backend` client and call the `generate` endpoint
    """
    # get endpoint
    endpoint = runtime.namespace("triton-init").component("vllm").endpoint("generate")

    # create client
    client = await endpoint.client()

    # list the endpoints
    print(client.endpoint_ids())

    # issue request
    stream = await client.generate(
        Request(
            prompt=prompt,
            sampling_params={"temperature": temperature, "max_tokens": max_tokens},
        ).model_dump_json()
    )

    # process response
    async for resp in stream:
        print(resp)


if __name__ == "__main__":
    uvloop.install()

    parser = FlexibleArgumentParser()
    parser.add_argument("--prompt", type=str, default="what is the capital of france?")
    parser.add_argument("--max-tokens", type=int, default=10)
    parser.add_argument("--temperature", type=float, default=0.5)

    args = parser.parse_args()

    asyncio.run(worker(args.prompt, args.max_tokens, args.temperature))
25 changes: 25 additions & 0 deletions examples/python_rs/llm/vllm/common/parser.py
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_vllm_args() -> AsyncEngineArgs:
    parser = FlexibleArgumentParser()
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    return AsyncEngineArgs.from_cli_args(args)
34 changes: 34 additions & 0 deletions examples/python_rs/llm/vllm/common/protocol.py
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pydantic import BaseModel


class Request(BaseModel):
    prompt: str
    sampling_params: dict


class PrefillRequest(Request):
    request_id: str


class Response(BaseModel):
    text: str


class PrefillResponse(BaseModel):
    prefilled: bool
92 changes: 92 additions & 0 deletions examples/python_rs/llm/vllm/disaggregated/decode_worker.py
@@ -0,0 +1,92 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import uuid

import uvloop
import vllm
from common.parser import parse_vllm_args
from common.protocol import PrefillRequest, Request, Response
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger


class VllmDecodeEngine:
    """
    Request handler for the generate endpoint
    """

    def __init__(self, engine_args: AsyncEngineArgs, prefill):
        assert (
            engine_args.kv_transfer_config.is_kv_consumer
        ), "Decode worker must be a KV consumer"
        self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
        self.prefill = prefill

    @triton_endpoint(Request, Response)
    async def generate(self, request):
        vllm_logger.info(f"Received request: {request}")
        sampling_params = vllm.SamplingParams(**request.sampling_params)
        request_id = str(uuid.uuid4())

        # Ask the remote prefill worker to process the prompt first (max_tokens=1),
        # so the KV cache is produced before local decoding starts.
        prefill_sampling_params = {**request.sampling_params}
        prefill_sampling_params["max_tokens"] = 1
        prefill_request = PrefillRequest(
            prompt=request.prompt,
            sampling_params=prefill_sampling_params,
            request_id=request_id,
        )
        prefill_generator = await self.prefill.generate(
            prefill_request.model_dump_json()
        )
        prefill_response = [resp async for resp in prefill_generator]
        assert len(prefill_response) == 1, "Prefill response should be a single boolean"
        prefill_response = prefill_response[0]
        vllm_logger.debug(f"Prefill response: {prefill_response}")

        async for response in self.engine.generate(
            request.prompt, sampling_params, request_id
        ):
            vllm_logger.debug(f"Generated response: {response}")
            yield response.outputs[0].text


@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """
    Instantiate a `backend` component and serve the `generate` endpoint
    A `Component` can serve multiple endpoints
    """
    component = runtime.namespace("triton-init").component("vllm")
    await component.create_service()

    prefill = (
        await runtime.namespace("triton-init")
        .component("prefill")
        .endpoint("generate")
        .client()
    )

    endpoint = component.endpoint("generate")
    await endpoint.serve_endpoint(VllmDecodeEngine(engine_args, prefill).generate)


if __name__ == "__main__":
    uvloop.install()
    engine_args = parse_vllm_args()
    asyncio.run(worker(engine_args))