[vllm] support base_url parameter for vLLM client initialization #3324
base: main
Conversation
- Introduce base_url parameter to simplify server connection
- Update __init__ method to support both base_url and host+port configurations
- Modify URL construction in various methods to use base_url
- Update documentation to include new initialization examples
7219281 to 38c9c8e
- Verify correct behavior of base_url attribute
- Update tests to include tensor parallelism scenarios
- Ensure proper cleanup of client resources
- Changed the host parameter from self.host to "0.0.0.0" in StatelessProcessGroup.create()
- This modification ensures that weight broadcasting works correctly in the vLLM client
- Update `VLLMClient` initialization to support base URL
- Modify existing parameters `vllm_server_host` and `vllm_server_port` to be ignored if base URL is provided
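Based on the diff in this PR, a minimal usage sketch of the two ways to construct the client (the server is assumed to already be running via `trl vllm-serve`; the import path, URL, and timeout values are illustrative assumptions):

from trl.extras.vllm_client import VLLMClient  # import path assumed

# Option 1: point the client at a full base URL (added by this PR)
client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

# Option 2: the existing host + port configuration (ignored when base_url is given)
client = VLLMClient(host="0.0.0.0", server_port=8000, connection_timeout=120)

# Both forms expose the same client API exercised by the tests in this PR, e.g.:
outputs = client.generate(["Hello, AI!"])  # returns a list of token-id lists
client.reset_prefix_cache()
client.close_communicator()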
…m_client_custom_url
@pytest.mark.slow
@require_3_gpus
class TestVLLMClientServerTPBaseURL(unittest.TestCase):
    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "1,2"  # Restrict to GPU 1 and 2

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Initialize the client with base_url
        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # Check that the output is a list
        self.assertIsInstance(outputs, list)

        # Check that the number of generated sequences is equal to the number of prompts
        self.assertEqual(len(outputs), len(prompts))

        # Check that the generated sequences are lists of integers
        for seq in outputs:
            self.assertTrue(all(isinstance(tok, int) for tok in seq))

    def test_update_model_params(self):
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes,
        # we need to kill the server process and its children explicitly.
        parent = psutil.Process(cls.server_process.pid)
        children = parent.children(recursive=True)
        for child in children:
            child.send_signal(signal.SIGTERM)
        cls.server_process.terminate()
        cls.server_process.wait()
I think one test is enough
@@ -47,10 +45,12 @@ class VLLMClient:
    weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

    Args:
        base_url (`str`, *optional*, defaults to `None`):
            Base URL for the vLLM server (e.g., "http://localhost:8000"). If provided, host and server_port are ignored.
Suggested change:
-            Base URL for the vLLM server (e.g., "http://localhost:8000"). If provided, host and server_port are ignored.
+            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, host and server_port are ignored.
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            IP address of the vLLM server.
+            IP address of the vLLM server. Ignored if base_url is provided.
Suggested change:
-            IP address of the vLLM server. Ignored if base_url is provided.
+            IP address of the vLLM server. Ignored if `base_url` is provided.
        server_port (`int`, *optional*, defaults to `8000`):
-            Port number of the vLLM server.
+            Port number of the vLLM server. Ignored if base_url is provided.
Suggested change:
-            Port number of the vLLM server. Ignored if base_url is provided.
+            Port number of the vLLM server. Ignored if `base_url` is provided.
if base_url is not None:
    # Parse the base_url to extract host and port
    parsed_url = urlparse(base_url)
    scheme = parsed_url.scheme or "http"
This allows passing something like "localhost:8000" instead of "http://localhost:8000", right?
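For reference, a quick sketch of how `urllib.parse.urlparse` handles these inputs on recent CPython versions (worth double-checking against the parsing logic in this PR; the `//`-prefixed form is shown only for illustration):

from urllib.parse import urlparse

print(urlparse("http://localhost:8000"))
# scheme='http', netloc='localhost:8000' -> .hostname is 'localhost', .port is 8000

print(urlparse("localhost:8000"))
# Recent Python versions parse this as scheme='localhost', path='8000',
# so .hostname is None: a bare "host:port" string is not treated as a netloc.

print(urlparse("//localhost:8000"))
# scheme='', netloc='localhost:8000' -> this is where a fallback like
# `parsed_url.scheme or "http"` would kick in.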
if args.vllm_server_base_url is not None:
    self.vllm_client = VLLMClient(
        base_url=args.vllm_server_base_url, connection_timeout=args.vllm_server_timeout
    )
else:
    self.vllm_client = VLLMClient(
        host=args.vllm_server_host,
        server_port=args.vllm_server_port,
        connection_timeout=args.vllm_server_timeout,
    )
maybe this is easier?
Suggested change:
base_url = args.vllm_server_base_url if args.vllm_server_base_url is not None else f"http://{args.vllm_server_host}:{args.vllm_server_port}"
self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
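For context, a hedged sketch of how the trainer config could select between the two modes, assuming this PR adds a `vllm_server_base_url` field alongside the existing `vllm_server_host`, `vllm_server_port`, and `vllm_server_timeout` fields (field names are inferred from the diff above, not confirmed against the final API; values are illustrative):

from trl import GRPOConfig

# Hypothetical values; `vllm_server_base_url` is the field this PR is assumed to add.
training_args = GRPOConfig(
    output_dir="out",
    use_vllm=True,
    vllm_server_base_url="http://my-vllm-box:8000",  # if set, host/port are ignored
    vllm_server_timeout=120.0,
)

# Equivalent host + port form:
training_args = GRPOConfig(
    output_dir="out",
    use_vllm=True,
    vllm_server_host="my-vllm-box",
    vllm_server_port=8000,
    vllm_server_timeout=120.0,
)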
# On the server side, the host is set to 0.0.0.0
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
if response.status_code != 200:
    raise Exception(f"Request failed: {response.status_code}, {response.text}")

# Set up the communication group for weight broadcasting
- pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
+ pg = StatelessProcessGroup.create(host="0.0.0.0", port=self.group_port, rank=self.rank, world_size=world_size)
I'm a bit worried about this change, but I don't have an example in mind where you could have host != 0.0.0.0.
I think this change makes sense, thanks for suggesting it.
Sorry for the delayed review.
I've made a few comments.
Have you tested on your infrastructure?
Related issue: #3322