Add a new api configure_logging to allow dumping the requests (#2875)
merrymercy authored Jan 13, 2025
1 parent 923f518 commit 46d4431
Showing 13 changed files with 164 additions and 71 deletions.
3rdparty/amd/profiling/PROFILING.md (2 changes: 1 addition & 1 deletion)

@@ -336,7 +336,7 @@ loadTracer.sh python3 -m sglang.launch_server \
     --model-path /sgl-workspace/sglang/dummy_grok1 \
     --tokenizer-path Xenova/grok-1-tokenizer \
     --load-format dummy \
-    --quant fp8 \
+    --quantization fp8 \
     --tp 8 \
     --port 30000 \
     --disable-radix-cache 2>&1 | tee "$LOGFILE"
3rdparty/amd/profiling/server.sh (2 changes: 1 addition & 1 deletion)

@@ -14,7 +14,7 @@ loadTracer.sh python3 -m sglang.launch_server \
     --model-path /sgl-workspace/sglang/dummy_grok1 \
     --tokenizer-path Xenova/grok-1-tokenizer \
     --load-format dummy \
-    --quant fp8 \
+    --quantization fp8 \
     --tp 8 \
     --port 30000 \
     --disable-radix-cache 2>&1 | tee "$LOGFILE"
3rdparty/amd/tuning/TUNING.md (2 changes: 1 addition & 1 deletion)

@@ -104,7 +104,7 @@ To maximize moe kernel efficiency, need to use below scripts to find out the bes

 ```bash
 #Tuning
-#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
+#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
 #so we can tune decode moe use below command
 python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
 # and use this command to tune prefill moe
benchmark/blog_v0_2/405b_sglang.sh (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@
 # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json

 # Launch sglang
-# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
+# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87

 # offline
 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
python/sglang/srt/managers/configure_logging.py (43 changes: 43 additions & 0 deletions, new file)

@@ -0,0 +1,43 @@
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Configure the logging settings of a server.
+Usage:
+python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    parser.add_argument(
+        "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
+    )
+    parser.add_argument("--dump-requests-threshold", type=int, default=1000)
+    args = parser.parse_args()
+
+    response = requests.post(
+        args.url + "/configure_logging",
+        json={
+            "dump_requests_folder": args.dump_requests_folder,
+            "dump_requests_threshold": args.dump_requests_threshold,
+        },
+    )
+    assert response.status_code == 200
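The script above only sets the dump folder and threshold, but the ConfigureLoggingReq struct added below also carries a log_requests field, so the same endpoint can toggle request logging at runtime. A minimal sketch of such a call, assuming the /configure_logging HTTP route (added elsewhere in this commit) maps the JSON body onto ConfigureLoggingReq and leaves missing fields unchanged:

```python
import requests

# Hypothetical example: turn off the verbose Receive/Finish log lines
# while leaving the dump settings untouched.
requests.post(
    "http://localhost:30000/configure_logging",
    json={"log_requests": False},
).raise_for_status()
```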
python/sglang/srt/managers/io_struct.py (7 changes: 7 additions & 0 deletions)

@@ -488,6 +488,13 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2


+@dataclass
+class ConfigureLoggingReq:
+    log_requests: Optional[bool] = None
+    dump_requests_folder: Optional[str] = None
+    dump_requests_threshold: Optional[int] = None
+
+
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
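All three fields default to None, and the configure_logging handler in tokenizer_manager.py below applies only the fields that are set, so a single request can change one setting at a time. A hypothetical in-process construction:

```python
# Enable request logging without touching the dump folder or threshold.
req = ConfigureLoggingReq(log_requests=True)
```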
python/sglang/srt/managers/scheduler.py (2 changes: 1 addition & 1 deletion)

@@ -82,6 +82,7 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -92,7 +93,6 @@
     set_random_seed,
     suppress_other_loggers,
 )
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
python/sglang/srt/managers/tokenizer_manager.py (41 changes: 39 additions & 2 deletions)

@@ -18,10 +18,12 @@
 import dataclasses
 import logging
 import os
+import pickle
 import signal
 import sys
 import time
 import uuid
+from datetime import datetime
 from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 import fastapi
@@ -43,6 +45,7 @@
     BatchStrOut,
     BatchTokenIDOut,
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
@@ -109,6 +112,7 @@ def __init__(
         # Parse args
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
+        self.log_requests = server_args.log_requests

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -167,6 +171,9 @@ def __init__(
         # Store states
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.dump_requests_folder = ""  # By default do not dump
+        self.dump_requests_threshold = 1000
+        self.dump_request_list: List[Tuple] = []

         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -225,7 +232,7 @@ async def generate_request(

         obj.normalize_batch_and_arguments()

-        if self.server_args.log_requests:
+        if self.log_requests:
             logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")

         async with self.model_update_lock.reader_lock:
@@ -346,7 +353,7 @@ async def _wait_one_response(

             state.out_list = []
             if state.finished:
-                if self.server_args.log_requests:
+                if self.log_requests:
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
                     logger.info(msg)
                 del self.rid_to_state[obj.rid]
@@ -597,6 +604,15 @@ async def close_session(
         assert not self.to_create_loop, "close session should not be the first request"
         await self.send_to_scheduler.send_pyobj(obj)

+    def configure_logging(self, obj: ConfigureLoggingReq):
+        if obj.log_requests is not None:
+            self.log_requests = obj.log_requests
+        if obj.dump_requests_folder is not None:
+            self.dump_requests_folder = obj.dump_requests_folder
+        if obj.dump_requests_threshold is not None:
+            self.dump_requests_threshold = obj.dump_requests_threshold
+        logging.info(f"Config logging: {obj=}")
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
@@ -708,6 +724,8 @@ async def handle_loop(self):

                     if self.enable_metrics:
                         self.collect_metrics(state, recv_obj, i)
+                    if self.dump_requests_folder and state.finished:
+                        self.dump_requests(state, out_dict)
                 elif isinstance(recv_obj, OpenSessionReqOutput):
                     self.session_futures[recv_obj.session_id].set_result(
                         recv_obj.session_id if recv_obj.success else None
@@ -850,6 +868,25 @@ def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
                 (time.time() - state.created_time) / completion_tokens
             )

+    def dump_requests(self, state: ReqState, out_dict: dict):
+        self.dump_request_list.append(
+            (state.obj, out_dict, state.created_time, time.time())
+        )
+
+        if len(self.dump_request_list) >= self.dump_requests_threshold:
+            to_dump = self.dump_request_list
+            self.dump_request_list = []
+
+            def background_task():
+                os.makedirs(self.dump_requests_folder, exist_ok=True)
+                current_time = datetime.now()
+                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
+                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                    pickle.dump(to_dump, f)
+
+            # Schedule the task to run in the background without awaiting it
+            asyncio.create_task(asyncio.to_thread(background_task))
+

 class SignalHandler:
     def __init__(self, tokenizer_manager):
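Each dump file written by dump_requests is a plain pickled list of (request, output, created_time, finished_time) tuples, so it can be inspected offline. A minimal sketch for reading the dumps back, assuming the default folder used by the script above (unpickling the request objects requires sglang on the import path, since they are pickled directly):

```python
import glob
import pickle

# Walk the dump folder and print a rough per-request latency.
for path in sorted(glob.glob("/tmp/sglang_request_dump/*.pkl")):
    with open(path, "rb") as f:
        records = pickle.load(f)
    for obj, out, created, finished in records:
        print(f"{path}: latency = {finished - created:.3f}s")
```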
python/sglang/srt/mem_cache/memory_pool.py (2 changes: 1 addition & 1 deletion)

@@ -13,7 +13,7 @@
 limitations under the License.
 """

-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

 """
 Memory pool.
python/sglang/srt/model_executor/model_runner.py (2 changes: 1 addition & 1 deletion)

@@ -50,6 +50,7 @@
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
@@ -60,7 +61,6 @@
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
-from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter

 logger = logging.getLogger(__name__)
(Diffs for the remaining changed files are not shown.)
