[vllm] support base_url parameter for vLLM client initialization #3324

Open · re-imagined wants to merge 7 commits into main from vllm_client_custom_url
Conversation


@re-imagined re-imagined commented Apr 18, 2025

related issue #3322

  • Introduce base_url parameter to simplify server connection
  • Update init method to support both base_url and host+port configurations
  • Modify URL construction in various methods to use base_url
  • Update documentation to include new initialization examples
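
A minimal usage sketch of the two initialization paths described above, assuming the client is imported from `trl.extras.vllm_client` as in the current codebase (the URL is a placeholder; parameter names match the diff further down):

```python
from trl.extras.vllm_client import VLLMClient

# New style: point the client at a full base URL.
# host and server_port are ignored when base_url is given.
client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

# Existing style, still supported: explicit host + port.
# client = VLLMClient(host="0.0.0.0", server_port=8000, connection_timeout=120)

outputs = client.generate(["Hello, AI!"])  # one list of token ids per prompt
```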

@re-imagined re-imagined changed the title feat(draft): add base_url parameter for vLLM server connection feat(draft): support base_url parameter for vLLM client initialization Apr 18, 2025
- Introduce base_url parameter to simplify server connection
- Update __init__ method to support both base_url and host+port configurations
- Modify URL construction in various methods to use base_url
- Update documentation to include new initialization examples
@re-imagined re-imagined force-pushed the vllm_client_custom_url branch from 7219281 to 38c9c8e on April 18, 2025 05:31
- Verify correct behavior of base_url attribute
- Update tests to include tensor parallelism scenarios
- Ensure proper cleanup of client resources
- Changed the host parameter from `self.host` to `"0.0.0.0"` in `StatelessProcessGroup.create()`
- This modification ensures that weight broadcasting works correctly in the vLLM client
@re-imagined re-imagined changed the title feat(draft): support base_url parameter for vLLM client initialization [vllm] support base_url parameter for vLLM client initialization Apr 22, 2025
@re-imagined re-imagined marked this pull request as ready for review April 22, 2025 13:03
- Update `VLLMClient` initialization to support base URL
- Modify existing parameters `vllm_server_host` and `vllm_server_port` to be ignored if base URL is provided
Comment on lines +241 to +301


@pytest.mark.slow
@require_3_gpus
class TestVLLMClientServerTPBaseURL(unittest.TestCase):
    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "1,2"  # Restrict to GPU 1 and 2

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Initialize the client with base_url
        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # Check that the output is a list
        self.assertIsInstance(outputs, list)

        # Check that the number of generated sequences is equal to the number of prompts
        self.assertEqual(len(outputs), len(prompts))

        # Check that the generated sequences are lists of integers
        for seq in outputs:
            self.assertTrue(all(isinstance(tok, int) for tok in seq))

    def test_update_model_params(self):
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
        # kill the server process and its children explicitly.
        parent = psutil.Process(cls.server_process.pid)
        children = parent.children(recursive=True)
        for child in children:
            child.send_signal(signal.SIGTERM)
        cls.server_process.terminate()
        cls.server_process.wait()
Suggested change: delete the `TestVLLMClientServerTPBaseURL` test class above.

I think one test is enough

@@ -47,10 +45,12 @@ class VLLMClient:
weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

Args:
base_url (`str`, *optional*, defaults to `None`):
Base URL for the vLLM server (e.g., "http://localhost:8000"). If provided, host and server_port are ignored.

Suggested change
Base URL for the vLLM server (e.g., "http://localhost:8000"). If provided, host and server_port are ignored.
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, host and server_port are ignored.

host (`str`, *optional*, defaults to `"0.0.0.0"`):
-    IP address of the vLLM server.
+    IP address of the vLLM server. Ignored if base_url is provided.

Suggested change
IP address of the vLLM server. Ignored if base_url is provided.
IP address of the vLLM server. Ignored if `base_url` is provided.

server_port (`int`, *optional*, defaults to `8000`):
-    Port number of the vLLM server.
+    Port number of the vLLM server. Ignored if base_url is provided.

Suggested change
Port number of the vLLM server. Ignored if base_url is provided.
Port number of the vLLM server. Ignored if `base_url` is provided.

if base_url is not None:
    # Parse the base_url to extract host and port
    parsed_url = urlparse(base_url)
    scheme = parsed_url.scheme or "http"
@qgallouedec (Member) Apr 26, 2025

this allows passing something like "localhost:8000" instead of "http://localhost:8000", right?
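
For reference, a quick check of what `urlparse` returns for the explicit-scheme form; the bare `host:port` case asked about here is not shown, since its parsing is less obvious:

```python
from urllib.parse import urlparse

parsed_url = urlparse("http://localhost:8000")
print(parsed_url.scheme)    # 'http'
print(parsed_url.netloc)    # 'localhost:8000'
print(parsed_url.hostname)  # 'localhost'
print(parsed_url.port)      # 8000
```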

Comment on lines +630 to +639
if args.vllm_server_base_url is not None:
    self.vllm_client = VLLMClient(
        base_url=args.vllm_server_base_url, connection_timeout=args.vllm_server_timeout
    )
else:
    self.vllm_client = VLLMClient(
        host=args.vllm_server_host,
        server_port=args.vllm_server_port,
        connection_timeout=args.vllm_server_timeout
    )
@qgallouedec (Member) Apr 26, 2025

maybe this is easier?

Suggested change
base_url = args.vllm_server_base_url if args.vllm_server_base_url is not None else f"http://{args.vllm_server_host}:{args.vllm_server_port}"
self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
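
For context, a hedged sketch of how the new setting could be used from the trainer side; the `vllm_server_base_url` name comes from the diff above, while routing it through `GRPOConfig` and the endpoint value are assumptions, not part of this PR's documented API:

```python
from trl import GRPOConfig

# Assumed wiring: the trainer reads vllm_server_base_url the same way it reads
# vllm_server_host / vllm_server_port today and forwards it to VLLMClient.
training_args = GRPOConfig(
    output_dir="Qwen2.5-1.5B-GRPO",
    use_vllm=True,
    vllm_server_base_url="http://my-vllm-host:8000",  # hypothetical endpoint
    vllm_server_timeout=240.0,
)
```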

# In the server side, the host is set to 0.0.0.0
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
if response.status_code != 200:
    raise Exception(f"Request failed: {response.status_code}, {response.text}")

# Set up the communication group for weight broadcasting
-pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
+pg = StatelessProcessGroup.create(host="0.0.0.0", port=self.group_port, rank=self.rank, world_size=world_size)

I'm a bit worried about this change, but I don't have any example in mind where you could have host != 0.0.0.0.
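
For orientation, the client-side flow these lines belong to, sketched from calls that appear elsewhere in this PR plus `init_communicator`, which the TRL client exposes to set up the weight-sync group; treat the exact sequence as an assumption rather than part of this diff:

```python
from transformers import AutoModelForCausalLM
from trl.extras.vllm_client import VLLMClient

client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

# Set up the StatelessProcessGroup-backed communicator, then broadcast weights.
client.init_communicator()
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B", device_map="cuda")
client.update_model_params(model)

client.reset_prefix_cache()
client.close_communicator()
```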

@qgallouedec (Member) left a comment

I think this change makes sense, thanks for suggesting it.
Sorry for the delayed review.
I've made a few comments.
Have you tested on your infrastructure?
