[Fix] Adjust default chunked prefill size and cuda graph max bs according to GPU memory capacity (#2044)

merrymercy authored Nov 15, 2024 · commit b01df48 (1 parent: c29b98e)
Showing 2 changed files with 50 additions and 3 deletions.
python/sglang/srt/server_args.py (17 additions, 3 deletions)
@@ -22,7 +22,12 @@
 import tempfile
 from typing import List, Optional
 
-from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
+from sglang.srt.utils import (
+    get_gpu_memory_capacity,
+    is_flashinfer_available,
+    is_ipv6,
+    is_port_available,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -143,6 +148,9 @@ def __post_init__(self):
             # Disable chunked prefill
             self.chunked_prefill_size = None
 
+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
         # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -156,8 +164,14 @@ def __post_init__(self):
             else:
                 self.mem_fraction_static = 0.88
 
-        if self.random_seed is None:
-            self.random_seed = random.randint(0, 1 << 30)
+        # Adjust for GPUs with small memory capacities
+        gpu_mem = get_gpu_memory_capacity()
+        if gpu_mem < 25000:
+            logger.warning(
+                "Automatically adjust --chunked-prefill-size for small GPUs."
+            )
+            self.chunked_prefill_size //= 4  # make it 2048
+            self.cuda_graph_max_bs = 4
 
         # Deprecation warnings
         if self.disable_flashinfer:
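For a concrete feel of the new branch, here is a minimal standalone sketch (an illustration, not the actual ServerArgs code): the 8192 starting value is implied by the "make it 2048" comment in the diff, while the 160 passed for cuda_graph_max_bs below is just a placeholder, not taken from this commit.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

def adjust_for_small_gpus(chunked_prefill_size, cuda_graph_max_bs, gpu_mem_mib):
    # Mirrors the new __post_init__ branch: when the smallest visible GPU
    # reports less than 25000 MiB (e.g. 24 GB consumer cards), shrink the
    # prefill chunk and cap the CUDA graph batch size.
    if gpu_mem_mib < 25000:
        logger.warning("Automatically adjust --chunked-prefill-size for small GPUs.")
        chunked_prefill_size //= 4
        cuda_graph_max_bs = 4
    return chunked_prefill_size, cuda_graph_max_bs

# A 24 GB card reports roughly 24564 MiB via nvidia-smi, so both knobs shrink:
print(adjust_for_small_gpus(8192, 160, 24564.0))  # -> (2048, 4)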
python/sglang/srt/utils.py (33 additions, 0 deletions)
@@ -27,6 +27,7 @@
 import shutil
 import signal
 import socket
+import subprocess
 import tempfile
 import time
 import warnings
@@ -791,3 +792,35 @@ def add_prometheus_middleware(app):
     # Workaround for 307 Redirect for /metrics
     metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
     app.routes.append(metrics_route)
+
+
+def get_gpu_memory_capacity():
+    try:
+        # Run nvidia-smi and capture the output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values
+        memory_values = [
+            float(mem)
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
+        )

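Since get_gpu_memory_capacity shells out to nvidia-smi, a quick way to exercise its parsing without a GPU is to stub subprocess.run. The sketch below uses made-up memory readings and assumes an environment where sglang is importable:

import subprocess
from unittest import mock

from sglang.srt.utils import get_gpu_memory_capacity

# Pretend nvidia-smi reported a 24 GB and an 80 GB card (values in MiB).
fake = subprocess.CompletedProcess(
    args=["nvidia-smi"], returncode=0, stdout="24564\n81559\n", stderr=""
)

with mock.patch("subprocess.run", return_value=fake):
    # The minimum across visible GPUs governs the server defaults.
    print(get_gpu_memory_capacity())  # -> 24564.0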