[vllm] support base_url parameter for vLLM client initialization #3324

Open · wants to merge 1 commit into base: main
76 changes: 76 additions & 0 deletions tests/test_vllm_client_server.py
@@ -162,3 +162,79 @@ def tearDownClass(cls):
child.send_signal(signal.SIGTERM)
cls.server_process.terminate()
cls.server_process.wait()


@pytest.mark.slow
@require_torch_multi_gpu
class TestVLLMClientServerBaseURL(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"

@classmethod
def setUpClass(cls):
# We want the server to run on GPU 1, so we set CUDA_VISIBLE_DEVICES to "1"
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "1" # Restrict to GPU 1

# Start the server process
cls.server_process = subprocess.Popen(
["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)

# Initialize the client with base_url
cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

def test_generate(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts)

# Check that the output is a list
self.assertIsInstance(outputs, list)

# Check that the number of generated sequences is equal to the number of prompts
self.assertEqual(len(outputs), len(prompts))

# Check that the generated sequences are lists of integers
for seq in outputs:
self.assertTrue(all(isinstance(tok, int) for tok in seq))

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)

# Check that the output is a list
self.assertIsInstance(outputs, list)

# Check that the number of generated sequences is 2 times the number of prompts
self.assertEqual(len(outputs), 2 * len(prompts))

# Check that the generated sequences are lists of integers
for seq in outputs:
self.assertTrue(all(isinstance(tok, int) for tok in seq))

# Check that the length of the generated sequences is less than or equal to 32
for seq in outputs:
self.assertLessEqual(len(seq), 32)

def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
self.client.update_model_params(model)

def test_reset_prefix_cache(self):
# Test resetting the prefix cache
self.client.reset_prefix_cache()

@classmethod
def tearDownClass(cls):
super().tearDownClass()

# Close the client
cls.client.close_communicator()

# vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
# kill the server process and its children explicitly.
parent = psutil.Process(cls.server_process.pid)
children = parent.children(recursive=True)
for child in children:
child.send_signal(signal.SIGTERM)
cls.server_process.terminate()
cls.server_process.wait()
77 changes: 59 additions & 18 deletions trl/extras/vllm_client.py
@@ -15,7 +15,8 @@
import atexit
import logging
import time
from typing import Optional
from typing import Optional, Union
from urllib.parse import urlparse

import torch
from torch import nn
@@ -47,10 +48,12 @@ class VLLMClient:
weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

Args:
base_url (`str`, *optional*, defaults to `None`):
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are ignored.
host (`str`, *optional*, defaults to `"0.0.0.0"`):
IP address of the vLLM server.
IP address of the vLLM server. Ignored if `base_url` is provided.
server_port (`int`, *optional*, defaults to `8000`):
Port number of the vLLM server.
Port number of the vLLM server. Ignored if `base_url` is provided.
group_port (`int`, *optional*, defaults to `51216`):
Port number for the weight update group.
connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -67,8 +70,25 @@ class VLLMClient:
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
```

Use the client to generate completions and update model weights:
There are two ways to initialize the client:

1. Using base_url:
```python
>>> from trl.extras.vllm_client import VLLMClient
>>> # Connect to a local server
>>> client = VLLMClient(base_url="http://localhost:8000")
>>> # Or connect to a remote server
>>> client = VLLMClient(base_url="http://192.168.1.100:8000")
```
2. Using host and server_port:
```python
>>> from trl.extras.vllm_client import VLLMClient
>>> # Connect to a local server
>>> client = VLLMClient(host="localhost", server_port=8000)
>>> # Or connect to a remote server
>>> client = VLLMClient(host="192.168.1.100", server_port=8000)
```
Use the client to generate completions and update model weights:
```python
>>> from trl.extras.vllm_client import VLLMClient
>>> client = VLLMClient()
@@ -78,25 +98,37 @@ class VLLMClient:

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
>>> client.init_communicator()
>>> client.update_model_params(model)
```
"""

def __init__(
self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
self,
base_url: Optional[str] = None,
host: str = "0.0.0.0",
server_port: int = 8000,
group_port: int = 51216,
connection_timeout: float = 0.0
):
if not is_requests_available():
raise ImportError("requests is not installed. Please install it with `pip install requests`.")
if not is_vllm_available():
raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")

self.session = requests.Session()
self.host = host
self.server_port = server_port

if base_url is not None:
# Parse the base_url to extract host and port
parsed_url = urlparse(base_url)
scheme = parsed_url.scheme or "http"
Review comment from @qgallouedec (Member), Apr 26, 2025: this allows passing something like "localhost:8000" instead of "http://localhost:8000", right?

Author reply: yes

self.host = parsed_url.hostname  # init_communicator() still needs a host for the weight-update group
self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
else:
self.host = host
self.server_port = server_port
self.base_url = f"http://{self.host}:{self.server_port}"
self.group_port = group_port
self.check_server(connection_timeout) # check server and fail after timeout
self.init_communicator()
atexit.register(self.close_communicator) # when the client object is deleted, close the weight update group
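
A minimal sketch of what the `urlparse`-based reconstruction above produces for URLs that carry an explicit scheme; the example URLs are illustrative (the second, with a path suffix, is hypothetical):

```python
from urllib.parse import urlparse

# Illustrative inputs; only the first appears in this PR's tests.
for raw in ("http://localhost:8000", "https://192.168.1.100:8000/vllm"):
    parsed = urlparse(raw)
    scheme = parsed.scheme or "http"  # same fallback as in __init__ above
    print(f"{scheme}://{parsed.netloc}{parsed.path}")

# Expected output:
# http://localhost:8000
# https://192.168.1.100:8000/vllm
```

How `urlparse` splits the scheme-less form discussed in the review thread ("localhost:8000") determines whether the `or "http"` fallback applies, so that case is worth verifying separately.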

def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
"""
@@ -109,7 +141,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
total_timeout (`float`, *optional*, defaults to `0.0`):
Total timeout duration in seconds.
"""
url = f"http://{self.host}:{self.server_port}/health/"
url = f"{self.base_url}/health/"
start_time = time.time() # Record the start time

while True:
Expand All @@ -120,7 +152,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
elapsed_time = time.time() - start_time
if elapsed_time >= total_timeout:
raise ConnectionError(
f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
f"The vLLM server can't be reached at {self.base_url} after {total_timeout} "
"seconds. Make sure the server is running by running `trl vllm-serve`."
) from exc
else:
Expand Down Expand Up @@ -171,7 +203,7 @@ def generate(
`list[list[int]]`:
List of lists of token IDs representing the model-generated completions for each prompt.
"""
url = f"http://{self.host}:{self.server_port}/generate/"
url = f"{self.base_url}/generate/"
response = self.session.post(
url,
json={
@@ -195,28 +227,36 @@ def init_communicator(self):
"""
Initializes the weight update group in a distributed setup for model synchronization.
"""
# Get the tensor parallel size from the server
url = f"http://{self.host}:{self.server_port}/get_tensor_parallel_size/"
# Get the world size from the server
url = f"{self.base_url}/get_world_size/"
response = requests.get(url)
if response.status_code == 200:
tensor_parallel_size = response.json()["tensor_parallel_size"]
vllm_world_size = response.json()["world_size"]
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

world_size = tensor_parallel_size + 1
self.rank = tensor_parallel_size # The client's rank is the last process
world_size = vllm_world_size + 1 # add the client to the world
self.rank = vllm_world_size # the client's rank is the last process

# Initialize weight update group
url = f"http://{self.host}:{self.server_port}/init_communicator/"
url = f"{self.base_url}/init_communicator/"
# On the server side, the host is set to 0.0.0.0
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

# Brief delay to allow server initialization. While not strictly required (client socket will retry on
# connection failure), this prevents log warnings like:
# [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
time.sleep(0.1)

# Set up the communication group for weight broadcasting
pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
self.pynccl_comm = PyNcclCommunicator(pg, device=0)

# When the client object is deleted, close the weight update group
atexit.register(self.close_communicator)

def update_named_param(self, name: str, weights: torch.Tensor):
"""
Updates a specific named parameter in the model and broadcasts it to other processes.
@@ -279,6 +319,7 @@ def close_communicator(self):
from vllm import SamplingParams

client = VLLMClient()
client.init_communicator()

# Generate completions
responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams())
18 changes: 14 additions & 4 deletions trl/trainer/grpo_config.py
@@ -85,10 +85,13 @@ class GRPOConfig(TrainingArguments):
use_vllm (`bool`, *optional*, defaults to `False`):
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`):
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
`vllm_server_port` are ignored.
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
Host of the vLLM server to connect to.
Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
vllm_server_port (`int`, *optional*, defaults to `8000`):
Port of the vLLM server to connect to.
Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
vllm_server_timeout (`float`, *optional*, defaults to `120.0`):
Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
timeout, a `ConnectionError` is raised.
@@ -270,13 +273,20 @@ class GRPOConfig(TrainingArguments):
"running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`."
},
)
vllm_server_base_url: Optional[str] = field(
default=None,
metadata={
"help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, vllm_server_host and "
"vllm_server_port are ignored."
},
)
vllm_server_host: str = field(
default="0.0.0.0",
metadata={"help": "Host of the vLLM server to connect to."},
metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
)
vllm_server_port: int = field(
default=8000,
metadata={"help": "Port of the vLLM server to connect to."},
metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
)
vllm_server_timeout: float = field(
default=120.0,
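
On the trainer side, a short sketch of how a training script might point GRPO at an already-running server with the new option (the URL and output directory are placeholders; everything else keeps its default):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="Qwen2.5-1.5B-GRPO",  # placeholder output directory
    use_vllm=True,
    # Reach an already-running `trl vllm-serve` instance by URL;
    # vllm_server_host and vllm_server_port are ignored when this is set.
    vllm_server_base_url="http://192.168.1.100:8000",
    vllm_server_timeout=120.0,
)
```

With `use_vllm=True` and `vllm_server_base_url` set, the host and port fields would be ignored, as the docstring above states.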