Skip to content

Commit

Permalink
Use an env var SGLANG_SET_CPU_AFFINITY to set cpu affinity; turn it o…
Browse files Browse the repository at this point in the history
…ff by default (#2217)
  • Loading branch information
merrymercy authored Nov 27, 2024
1 parent 37c8a57 commit a0e5874
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 20 deletions.
4 changes: 2 additions & 2 deletions python/sglang/bench_one_batch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
Usage:
python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
"""

import argparse
Expand Down
8 changes: 2 additions & 6 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

import json
import logging
import os
from enum import IntEnum, auto
from typing import List, Optional

from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.utils import get_bool_env_var

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,13 +59,9 @@ def __init__(

# Derive context length
derived_context_len = get_context_length(self.hf_text_config)
allow_long_context = os.environ.get(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
)

if context_length is not None:
if context_length > derived_context_len:
if allow_long_context:
if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
logger.warning(
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors."
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available
from sglang.srt.utils import get_bool_env_var, is_flashinfer_available

if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
Expand Down Expand Up @@ -47,8 +47,8 @@ def __init__(self, model_runner: ModelRunner):

# Parse constants
if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
self.decode_use_tensor_cores = (
os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
self.decode_use_tensor_cores = get_bool_env_var(
"SGLANG_FLASHINFER_USE_TENSOR_CORE"
)
else:
if not _grouped_size_compiled_for_decode_kernels(
Expand Down
8 changes: 5 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@
broadcast_pyobj,
configure_logger,
crash_on_warnings,
get_bool_env_var,
get_zmq_socket,
gpu_proc_affinity,
kill_parent_process,
set_gpu_proc_affinity,
set_random_seed,
suppress_other_loggers,
)
Expand All @@ -82,7 +83,7 @@
logger = logging.getLogger(__name__)

# Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


class Scheduler:
Expand Down Expand Up @@ -1405,7 +1406,8 @@ def run_scheduler_process(
pipe_writer,
):
# set cpu affinity to this gpu process
gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

# [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and "DP_RANK" in os.environ:
Expand Down
13 changes: 9 additions & 4 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def is_flashinfer_available():
Check whether flashinfer is available.
As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
"""
if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
if get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
return False
return torch.cuda.is_available() and not is_hip()

Expand Down Expand Up @@ -626,7 +626,7 @@ async def authentication(request, call_next):


def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
if "SGLANG_USE_MODELSCOPE" in os.environ:
if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
if not os.path.exists(model_path):
from modelscope import snapshot_download

Expand Down Expand Up @@ -931,7 +931,7 @@ def get_nvgpu_memory_capacity():

def crash_on_warnings():
# Crash on warning if we are running CI tests
return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"
return get_bool_env_var("SGLANG_IS_IN_CI")


def get_device_name(device_id: int = 0) -> str:
Expand Down Expand Up @@ -990,7 +990,7 @@ def direct_register_custom_op(
my_lib._register_fake(op_name, fake_impl)


def gpu_proc_affinity(
def set_gpu_proc_affinity(
tp_size: int,
nnodes: int,
gpu_id: int,
Expand Down Expand Up @@ -1022,3 +1022,8 @@ def gpu_proc_affinity(
# set cpu_affinity to current process
p.cpu_affinity(bind_cpu_ids)
logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")


def get_bool_env_var(name: str, default: str = "false") -> bool:
value = os.getenv(name, default)
return value.lower() in ("true", "1")
4 changes: 2 additions & 2 deletions python/sglang/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.utils import kill_child_process
from sglang.srt.utils import get_bool_env_var, kill_child_process
from sglang.test.run_eval import run_eval
from sglang.utils import get_exception_traceback

Expand All @@ -44,7 +44,7 @@

def is_in_ci():
"""Return whether it is in CI runner."""
return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"
return get_bool_env_var("SGLANG_IS_IN_CI")


if is_in_ci():
Expand Down

0 comments on commit a0e5874

Please sign in to comment.