diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c20aac847d..1531de09fe 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -22,7 +22,12 @@
 import tempfile
 from typing import List, Optional
 
-from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
+from sglang.srt.utils import (
+    get_gpu_memory_capacity,
+    is_flashinfer_available,
+    is_ipv6,
+    is_port_available,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -143,6 +148,9 @@ def __post_init__(self):
         # Disable chunked prefill
         self.chunked_prefill_size = None
 
+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
         # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -156,8 +164,16 @@ def __post_init__(self):
         else:
             self.mem_fraction_static = 0.88
 
-        if self.random_seed is None:
-            self.random_seed = random.randint(0, 1 << 30)
+        # Adjust for GPUs with small memory capacities.
+        # NOTE(review): get_gpu_memory_capacity() raises RuntimeError when
+        # nvidia-smi is unavailable — confirm non-NVIDIA deployments still start.
+        gpu_mem = get_gpu_memory_capacity()
+        if gpu_mem < 25000:
+            logger.warning(
+                "Automatically adjust --chunked-prefill-size for small GPUs."
+            )
+            # chunked_prefill_size may already be None (chunked prefill disabled
+            # above); floor-dividing None would raise TypeError at startup.
+            if self.chunked_prefill_size is not None:
+                self.chunked_prefill_size //= 4  # make it 2048
+            self.cuda_graph_max_bs = 4
 
         # Deprecation warnings
         if self.disable_flashinfer:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index ff0ef5e421..f7e32e653f 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -27,6 +27,7 @@
 import shutil
 import signal
 import socket
+import subprocess
 import tempfile
 import time
 import warnings
@@ -791,3 +792,41 @@ def add_prometheus_middleware(app):
     # Workaround for 307 Redirect for /metrics
     metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
     app.routes.append(metrics_route)
+
+
+def get_gpu_memory_capacity():
+    """Return the smallest total memory (in MiB) among the visible NVIDIA GPUs.
+
+    Raises:
+        RuntimeError: if nvidia-smi is missing or exits with an error.
+        ValueError: if nvidia-smi output contains no parsable memory values.
+    """
+    try:
+        # Run nvidia-smi and capture the output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values
+        memory_values = [
+            float(mem)
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # The most memory-constrained GPU governs the tuning decision
+        return min(memory_values)
+
+    except FileNotFoundError as e:
+        raise RuntimeError(
+            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
+        ) from e